heavyball 0.21.0__tar.gz → 0.21.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.21.0 → heavyball-0.21.2}/PKG-INFO +1 -1
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/cached_delayed_psgd_kron.py +5 -5
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/cached_psgd_kron.py +2 -3
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/delayed_psgd.py +10 -4
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/utils.py +14 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.21.0 → heavyball-0.21.2}/setup.py +1 -1
- {heavyball-0.21.0 → heavyball-0.21.2}/LICENSE +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/README.md +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/__init__.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/p_adam.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/setup.cfg +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_bf16_params.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_bf16_q.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_bf16_storage.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_closure.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_ema.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_foreach.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_memory.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_merge.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_no_grad.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_psgd.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.2}/test/test_stochastic_updates.py +0 -0
{heavyball-0.21.0 → heavyball-0.21.2}/heavyball/cached_delayed_psgd_kron.py

```diff
@@ -7,9 +7,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
+from heavyball.utils import min_dtype, precond_grad_cached_
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -59,7 +60,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -118,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
 
             if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -130,5 +132,3 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                     torch.matmul(q_.T.conj(), q_, out=c_)
                 else:
                     torch.mul(q_.conj(), q_, out=c_)
-
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
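The ordering in the last two hunks is what makes this optimizer "delayed": the parameter is stepped with the stale cached preconditioner first, and the cache is only rebuilt afterwards, so a refreshed preconditioner takes effect on the next step. Below is a minimal self-contained sketch of that ordering; the function, its argument names, and the einsum string are illustrative, not heavyball's exact API.

```python
import torch


def delayed_cached_step(p, ea, cached_q, expr, new_q, should_update, lr):
    """Sketch of the off-by-one ('delayed') cached PSGD step shown above."""
    # 1) Precondition the smoothed gradient with the *existing* cache and
    #    step the parameter (this mirrors the precond_grad_cached_ call).
    update = torch.einsum(expr, *cached_q, ea)
    p.sub_(update * lr)
    # 2) Only afterwards, and only on some steps, rebuild the cache from the
    #    freshly updated Kronecker factors, so the new preconditioner is
    #    first applied on the *next* step.
    if should_update:
        for q_, c_ in zip(new_q, cached_q):
            if q_.dim() == 2:
                torch.matmul(q_.T.conj(), q_, out=c_)
            else:
                torch.mul(q_.conj(), q_, out=c_)
```

For a 2-D parameter with cached factors `A` (m×m) and `B` (n×n), `expr` would be an einsum such as `'ab,cd,bd->ac'`, i.e. roughly `A @ ea @ B.T`.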
{heavyball-0.21.0 → heavyball-0.21.2}/heavyball/cached_psgd_kron.py

```diff
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -128,5 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
                 else:
                     torch.mul(q_.conj(), q_, out=c_)
 
-
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
```
{heavyball-0.21.0 → heavyball-0.21.2}/heavyball/delayed_psgd.py

```diff
@@ -5,12 +5,18 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import stochastic_lerp_, beta_debias
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     split_p_and_g_in_group, triu_to_line, line_to_triu, promote
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay):
+    new = psgd_precond_grad(q, exprs, ea)
+    update_param_([p], self.clip_fn([new]), lr, weight_decay)
+
+
 class ForeachDelayedPSGD(PSGDBase):
     """
     Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
@@ -39,7 +45,7 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32',
+                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -98,15 +104,15 @@ class ForeachDelayedPSGD(PSGDBase):
         stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])
+        lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)
 
         Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-
+            _compilable_psgd_precond_grad_(q, state["exprs"], ea, p, lr, weight_decay)
             if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
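Two related changes recur in this hunk and in utils.py below: the per-parameter update is moved into a `torch.compile`d helper, and the Python-float learning rate is first materialized as a 0-dim float32 tensor so the compiled graph does not treat it as a constant (which would force a recompile whenever warmup changes its value). Here is a minimal sketch of that pattern, separate from heavyball's own `update_param_`/`clip_fn` machinery; the helper names are mine.

```python
import torch


@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
def _compiled_sgd_update_(p, update, lr, weight_decay):
    # `lr` arrives as a 0-dim tensor: tensor inputs stay dynamic inside the
    # graph, whereas a Python float would be baked in as a graph constant.
    p.mul_(1 - lr * weight_decay)  # decoupled weight decay (illustrative)
    p.sub_(update * lr)            # plain SGD step with the preconditioned update


def apply_update(p, update, lr: float, weight_decay: float = 0.0):
    # Materialize the Python float as a 0-dim tensor before crossing the
    # compile boundary, the same way the diff does.
    lr_t = torch.empty((), dtype=torch.float32, device=p.device).fill_(lr)
    _compiled_sgd_update_(p, update, lr_t, weight_decay)
```

Usage would be `apply_update(param, preconditioned_grad, lr=1e-3, weight_decay=0.01)`; the learning rate can then change every step without triggering recompilation.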
{heavyball-0.21.0 → heavyball-0.21.2}/heavyball/utils.py

```diff
@@ -965,6 +965,20 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
+    md = min_dtype(cached_q + [ea])
+    new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
+    update_param_([param], self.clip_fn([new]), lr, weight_decay)
+
+
+def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
+                         weight_decay: float):
+    if isinstance(lr, float):
+        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
+    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
+
+
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
     """Anneal preconditioner update probability during beginning of training.
 
```
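For reference, here is a self-contained sketch of what the new `precond_grad_cached_` path computes: cast the cached Kronecker factors and the smoothed gradient to a common dtype, contract them with the cached einsum expression, and apply the result to the parameter. The `_common_dtype` helper, the sign convention, and the omission of heavyball's `clip_fn` are simplifications of mine (note the released compiled helper above still references `self` and `p` rather than its own `expr`/`param` arguments).

```python
from functools import reduce
from typing import List

import torch


def _common_dtype(tensors: List[torch.Tensor]) -> torch.dtype:
    # Stand-in for heavyball's min_dtype: a dtype every operand can share.
    return reduce(torch.promote_types, (t.dtype for t in tensors))


def precond_grad_cached(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str,
                        param: torch.Tensor, lr: float, weight_decay: float) -> None:
    # Contract the cached factors (roughly Q^T Q per dimension) with the
    # smoothed gradient `ea`, then take a decoupled-weight-decay SGD step.
    md = _common_dtype(cached_q + [ea])
    new = torch.einsum(expr, *[c.to(md) for c in cached_q], ea.to(md))
    param.mul_(1 - lr * weight_decay)
    param.sub_(new.to(param.dtype) * lr)
```

The float-or-tensor handling of `lr` in the released wrapper (`isinstance(lr, float)` followed by a 0-dim tensor fill) is the same compile-friendly pattern shown after the delayed_psgd.py diff above.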