heavyball 0.21.0__tar.gz → 0.21.1__tar.gz
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- {heavyball-0.21.0 → heavyball-0.21.1}/PKG-INFO +1 -1
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_delayed_psgd_kron.py +5 -5
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_psgd_kron.py +2 -3
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/delayed_psgd.py +3 -3
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/utils.py +14 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.21.0 → heavyball-0.21.1}/setup.py +1 -1
- {heavyball-0.21.0 → heavyball-0.21.1}/LICENSE +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/README.md +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/__init__.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/p_adam.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/setup.cfg +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_params.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_q.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_storage.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_closure.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_ema.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_foreach.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_memory.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_merge.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_no_grad.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_psgd.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_soap.py +0 -0
- {heavyball-0.21.0 → heavyball-0.21.1}/test/test_stochastic_updates.py +0 -0
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_delayed_psgd_kron.py

```diff
@@ -7,9 +7,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
+from heavyball.utils import min_dtype, precond_grad_cached_
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -59,7 +60,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -118,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
 
             if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -130,5 +132,3 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         torch.matmul(q_.T.conj(), q_, out=c_)
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
-
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
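For orientation, here is a minimal usage sketch of the optimizer touched above. It assumes the class is re-exported from the package root (the diff does not show `heavyball/__init__.py` changing, so the import path is an assumption) and that the constructor accepts `lr` together with the keyword arguments visible in these hunks; treat it as an illustration, not as documented API.

```python
# Illustrative only: assumes ForeachCachedDelayedPSGDKron is importable from the
# heavyball package root and accepts the keyword arguments visible in the diff
# (q_dtype, storage_dtype, store_triu_as_line, ...). Not taken from package docs.
import torch
import heavyball

model = torch.nn.Linear(16, 4)
opt = heavyball.ForeachCachedDelayedPSGDKron(
    model.parameters(),
    lr=1e-3,                  # learning rate (assumed constructor argument)
    q_dtype='float32',        # dtype of the preconditioner factors Q
    storage_dtype='float32',  # forwarded into the optimizer defaults as of 0.21.1
)

x = torch.randn(8, 16)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```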
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_psgd_kron.py

```diff
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -128,5 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
```
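Both cached Kron variants now route their final parameter step through `precond_grad_cached_`. As a self-contained sketch of the general pattern that helper centralizes (not heavyball's implementation): contract the cached Kronecker factors against the gradient EMA with an einsum expression, then apply an SGD-style update with decoupled weight decay. The function name, einsum string, and weight-decay form below are illustrative assumptions.

```python
# Standalone sketch of the "cached preconditioned gradient + parameter update"
# pattern. Names and the einsum expression are illustrative; heavyball builds
# its expression strings dynamically per parameter shape.
from typing import List
import torch


def apply_cached_precond(cached_q: List[torch.Tensor], ea: torch.Tensor,
                         expr: str, param: torch.Tensor,
                         lr: float, weight_decay: float) -> None:
    # Contract the cached preconditioner factors (one per tensor dimension)
    # with the exponential moving average of the gradient.
    update = torch.einsum(expr, *cached_q, ea)
    with torch.no_grad():
        if weight_decay > 0:
            param.mul_(1 - lr * weight_decay)  # decoupled weight decay (assumed form)
        param.add_(update, alpha=-lr)          # gradient step


# Example: a 2D parameter preconditioned by one factor per dimension.
p = torch.randn(4, 3)
ea = torch.randn(4, 3)                         # stand-in for the EMA of the gradient
cached = [torch.eye(4), torch.eye(3)]          # identity factors -> plain step
apply_cached_precond(cached, ea, 'ab,cd,bd->ac', p, lr=1e-2, weight_decay=1e-4)
```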
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/delayed_psgd.py

```diff
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import stochastic_lerp_, beta_debias
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     split_p_and_g_in_group, triu_to_line, line_to_triu, promote
 
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32',
+                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -105,8 +105,8 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+            update_param_([p], self.clip_fn([new]), lr, weight_decay)
             if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
```
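The delayed_psgd.py change moves `update_param_` ahead of the preconditioner refresh, so the step is taken with the current Q before Q is (stochastically) updated. A schematic of that ordering, with placeholder functions standing in for `psgd_precond_grad`, `update_param_`, and the Q update:

```python
# Schematic of the "delayed" ordering after this change: precondition and apply
# the step with the existing Q, then (possibly) refresh Q for future steps.
# precondition(), apply_update() and refresh_preconditioner() are placeholders,
# not heavyball functions.
import random
import torch


def precondition(q: torch.Tensor, ea: torch.Tensor) -> torch.Tensor:
    return q @ ea                          # stand-in for psgd_precond_grad


def apply_update(p: torch.Tensor, new: torch.Tensor, lr: float) -> None:
    with torch.no_grad():
        p.add_(new, alpha=-lr)             # stand-in for update_param_


def refresh_preconditioner(q: torch.Tensor, ea: torch.Tensor) -> None:
    pass                                   # stand-in for the PSGD Q update


def delayed_step(p: torch.Tensor, ea: torch.Tensor, q: torch.Tensor,
                 lr: float, update_prob: float = 0.1) -> None:
    new = precondition(q, ea)              # precondition with the *current* Q
    apply_update(p, new, lr)               # parameter step happens first ...
    if random.random() < update_prob:      # ... then Q may be stochastically refreshed
        refresh_preconditioner(q, ea)


p, ea, q = torch.randn(3), torch.randn(3), torch.eye(3)
delayed_step(p, ea, q, lr=1e-2)
```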
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/utils.py

```diff
@@ -965,6 +965,20 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
+    md = min_dtype(cached_q + [ea])
+    new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
+    update_param_([param], self.clip_fn([new]), lr, weight_decay)
+
+
+def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
+                         weight_decay: float):
+    if isinstance(lr, float):
+        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
+    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
+
+
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
     """Anneal preconditioner update probability during beginning of training.
 
```
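The new `precond_grad_cached_` wrapper converts a Python float learning rate into a 0-dim tensor before calling the `torch.compile`d inner function, which avoids recompiling the graph for every distinct scalar value. A standalone sketch of that wrapper pattern follows; the names and the update rule are illustrative, and it assumes a PyTorch build where `torch.compile` is available.

```python
# Minimal sketch of the float-lr -> 0-dim tensor wrapper around a
# torch.compile'd kernel, mirroring the precond_grad_cached_ addition above.
import torch


@torch.compile(fullgraph=True, dynamic=True)
def _compiled_step(param: torch.Tensor, update: torch.Tensor, lr: torch.Tensor) -> None:
    # lr arrives as a tensor, so changing its value does not force a recompile.
    param.add_(update * -lr)


def step(param: torch.Tensor, update: torch.Tensor, lr) -> None:
    if isinstance(lr, float):
        # Same conversion the diff shows: materialize the scalar on the right device.
        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
    _compiled_step(param, update, lr)


p = torch.zeros(4)
step(p, torch.ones(4), 1e-2)
```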