heavyball 0.21.3__py3-none-any.whl → 0.21.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +1 -1
- heavyball/cached_psgd_kron.py +1 -1
- heavyball/delayed_psgd.py +4 -4
- heavyball/p_adam.py +2 -2
- heavyball/utils.py +4 -4
- {heavyball-0.21.3.dist-info → heavyball-0.21.4.dist-info}/METADATA +1 -1
- {heavyball-0.21.3.dist-info → heavyball-0.21.4.dist-info}/RECORD +10 -10
- {heavyball-0.21.3.dist-info → heavyball-0.21.4.dist-info}/LICENSE +0 -0
- {heavyball-0.21.3.dist-info → heavyball-0.21.4.dist-info}/WHEEL +0 -0
- {heavyball-0.21.3.dist-info → heavyball-0.21.4.dist-info}/top_level.txt +0 -0
@@ -120,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
120
120
|
q_orig = Q_list.pop(0)
|
121
121
|
ea = exp_avg_list.pop(0)
|
122
122
|
|
123
|
-
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
|
123
|
+
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
|
124
124
|
|
125
125
|
if should_update:
|
126
126
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
heavyball/cached_psgd_kron.py
CHANGED
@@ -128,4 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
128
128
|
else:
|
129
129
|
torch.mul(q_.conj(), q_, out=c_)
|
130
130
|
|
131
|
-
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
|
131
|
+
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
|
heavyball/delayed_psgd.py
CHANGED
@@ -12,9 +12,9 @@ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust
|
|
12
12
|
|
13
13
|
|
14
14
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
15
|
-
def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr,
|
15
|
+
def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_deca, clip_fn):
|
16
16
|
new = psgd_precond_grad(q, exprs, ea)
|
17
|
-
update_param_([p],
|
17
|
+
update_param_([p], clip_fn([new]), lr, weight_decay)
|
18
18
|
|
19
19
|
|
20
20
|
class ForeachDelayedPSGD(PSGDBase):
|
@@ -62,7 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
62
62
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
63
63
|
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
64
64
|
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
65
|
-
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
65
|
+
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
|
66
66
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
67
67
|
|
68
68
|
def _step(self, group):
|
@@ -111,7 +111,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
111
111
|
q_orig = Q_list.pop(0)
|
112
112
|
ea = exp_avg_list.pop(0)
|
113
113
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
114
|
-
_compilable_psgd_precond_grad_(q,
|
114
|
+
_compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn)
|
115
115
|
if should_update:
|
116
116
|
q32 = [promote(q_) for q_ in q]
|
117
117
|
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
heavyball/p_adam.py
CHANGED
@@ -6,7 +6,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
6
6
|
|
7
7
|
import torch
|
8
8
|
|
9
|
-
from heavyball.utils import triu_to_line, line_to_triu, identity
|
9
|
+
from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
|
10
10
|
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
|
11
11
|
split_p_and_g_in_group, promote
|
12
12
|
|
@@ -100,7 +100,7 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
100
100
|
for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
|
101
101
|
q32 = [promote(qq_) for qq_ in q_]
|
102
102
|
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
103
|
-
|
103
|
+
stochastic_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
|
104
104
|
|
105
105
|
beta2 = 1 - group['step'] ** -group['beta2_scale']
|
106
106
|
|
heavyball/utils.py
CHANGED
@@ -966,17 +966,17 @@ class PSGDBase(StatefulOptimizer):
|
|
966
966
|
|
967
967
|
|
968
968
|
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
969
|
-
def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
|
969
|
+
def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
|
970
970
|
md = min_dtype(cached_q + [ea])
|
971
971
|
new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
|
972
|
-
update_param_([param],
|
972
|
+
update_param_([param], clip_fn([new]), lr, weight_decay)
|
973
973
|
|
974
974
|
|
975
975
|
def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
|
976
|
-
weight_decay: float):
|
976
|
+
weight_decay: float, clip_fn):
|
977
977
|
if isinstance(lr, float):
|
978
978
|
lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
|
979
|
-
_compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
|
979
|
+
_compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)
|
980
980
|
|
981
981
|
|
982
982
|
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
|
@@ -1,13 +1,13 @@
|
|
1
1
|
heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
|
2
|
-
heavyball/cached_delayed_psgd_kron.py,sha256=
|
3
|
-
heavyball/cached_psgd_kron.py,sha256=
|
4
|
-
heavyball/delayed_psgd.py,sha256=
|
2
|
+
heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
|
3
|
+
heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
|
4
|
+
heavyball/delayed_psgd.py,sha256=T4OzqGgiycbxuTYJyMSCEI3PBWqPdC6g29KgrNg-JHg,5984
|
5
5
|
heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
|
6
6
|
heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
|
7
7
|
heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
|
8
8
|
heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
|
9
9
|
heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
|
10
|
-
heavyball/p_adam.py,sha256=
|
10
|
+
heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
|
11
11
|
heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
|
12
12
|
heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
|
13
13
|
heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
|
|
16
16
|
heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
|
17
17
|
heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.21.
|
21
|
-
heavyball-0.21.
|
22
|
-
heavyball-0.21.
|
23
|
-
heavyball-0.21.
|
24
|
-
heavyball-0.21.
|
19
|
+
heavyball/utils.py,sha256=L9OaPVWBGl6hyXbt9cYsq4QRUsX4tUIgokeGuywHT84,37209
|
20
|
+
heavyball-0.21.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.21.4.dist-info/METADATA,sha256=ApqLqW-YQ9iqvfRMQu-Y9CEo0DMySpMxUdpV57Xilb4,11926
|
22
|
+
heavyball-0.21.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.21.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.21.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|