heavyball 0.18.2.tar.gz → 0.18.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.2 → heavyball-0.18.4}/PKG-INFO +1 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/cached_delayed_psgd_kron.py +2 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/cached_psgd_kron.py +4 -3
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/delayed_psgd.py +2 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/p_adam.py +2 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/psgd_kron.py +2 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/pure_psgd.py +2 -2
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/utils.py +12 -13
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/setup.py +1 -1
- {heavyball-0.18.2 → heavyball-0.18.4}/LICENSE +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/README.md +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/__init__.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/setup.cfg +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_closure.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_foreach.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_memory.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_merge.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_no_grad.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_psgd.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_soap.py +0 -0
- {heavyball-0.18.2 → heavyball-0.18.4}/test/test_stochastic_updates.py +0 -0
heavyball/cached_delayed_psgd_kron.py
@@ -62,6 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -115,7 +116,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
 
-            if self.should_update(group):
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
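The same refactor recurs across the optimizer files below: `self.should_update(group)` appears to be the stochastic draw against the preconditioner-update schedule (the constructors take `stochastic_schedule` and `preconditioner_update_probability`), and 0.18.4 hoists it out of the per-tensor loop so it is evaluated once per `_step` and reused. A minimal sketch of the pattern, with simplified, hypothetical names rather than heavyball's actual API:

import random

class TinyPSGDLike:
    def __init__(self, update_probability=0.1):
        self.update_probability = update_probability

    def should_update(self):
        # single stochastic draw against the schedule's probability
        return random.random() < self.update_probability

    def step(self, params):
        should_update = self.should_update()  # hoisted: one draw per step
        for p in params:
            if should_update:  # identical decision for every tensor in the group
                self.update_preconditioner(p)

    def update_preconditioner(self, p):
        pass  # stand-in for the real PSGD preconditioner update

With the draw inside the loop, tensors in the same group could disagree about whether this is an update step; hoisting makes the decision group-wide and consistent.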
heavyball/cached_psgd_kron.py
@@ -8,8 +8,8 @@ from typing import Optional
 
 import torch
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    …
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -71,6 +71,7 @@ class ForeachCachedPSGDKron(PSGDBase):
         beta = group['beta']
         store_triu_as_line = group['store_triu_as_line']
         q_dtype = getattr(torch, group['q_dtype'])
+        should_update = self.should_update(group)
 
         vals = []
 
@@ -111,7 +112,7 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            if self.should_update(group):
+            if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
heavyball/delayed_psgd.py
@@ -62,6 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -103,7 +104,7 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)
heavyball/p_adam.py
@@ -61,6 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
        precond_init_scale = group['precond_init_scale']
        max_size_triangular = group['max_size_triangular']
        min_ndim_triangular = group['min_ndim_triangular']
@@ -94,7 +95,7 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1
 
         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if self.should_update(group):
+        if should_update:
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
heavyball/psgd_kron.py
@@ -60,6 +60,7 @@ class ForeachPSGDKron(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
+        should_update = self.should_update(group)
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -101,7 +102,7 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
heavyball/pure_psgd.py
@@ -57,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
-
+        should_update = self.should_update(group)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -93,7 +93,7 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            if self.should_update(group):
+            if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
heavyball/utils.py
@@ -335,9 +335,7 @@ def promote(x):
 def min_dtype(xs: List[torch.Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
-        if all(d …
-            return d
-        if all(d in (x, torch.float32, torch.float64) for x in dtypes):
+        if all(x in (d, torch.float32, torch.float64) for x in dtypes):
             return d
     return torch.float32
 
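This hunk fixes the membership test in min_dtype: the old code asked whether the candidate dtype `d` was in each tensor's set, the new code asks whether every tensor's dtype `x` is either `d` itself or a wider float. A standalone re-implementation of the corrected logic (for illustration only), with a case the old test got wrong:

import torch
from typing import List

def min_dtype(xs: List[torch.Tensor]):
    dtypes = [x.dtype for x in xs]
    for d in (torch.float32, torch.bfloat16, torch.float16):
        # corrected test: every tensor is d itself, or a wider float
        # that can safely be cast down to d
        if all(x in (d, torch.float32, torch.float64) for x in dtypes):
            return d
    return torch.float32

a = torch.zeros(2, dtype=torch.bfloat16)
b = torch.zeros(2, dtype=torch.float32)
print(min_dtype([a, b]))  # torch.bfloat16 — the old `d in (x, ...)` test
                          # already succeeded at d == torch.float32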
@@ -481,7 +479,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
         copy_stochastic_(t, s)
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=…
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
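The only change here is `dynamic=True` on the torch.compile decorator (the old value is truncated in this diff). With `dynamic=True` the compiler is asked for a shape-generic graph, so `_compilable_copy_stochastic_` is not re-specialized and recompiled for every distinct parameter shape it sees. A toy illustration of the flag (hypothetical function, not heavyball's code):

import torch

@torch.compile(fullgraph=True, dynamic=True)
def scale(x: torch.Tensor, factor: float):
    return x * factor

for shape in ((4,), (16, 16), (3, 5, 7)):
    scale(torch.randn(shape), 0.5)  # one shape-generic graph, no per-shape recompiles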
@@ -817,20 +815,21 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-        for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
+        if original_q:
+            if store_triu_as_line:
+                update_fn = update_triu_
+            else:
+                update_fn = copy_stochastic_list_
+        else:
+            update_fn = lambda x, y: None
+        for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
+            update_fn(oq, Q)
 
-        for g, q in zip(grad_list, q_list):
+        for g, q in zip(grad_list, original_q if original_q else q_list):
             if g.dim() > 1:
                 psgd_balance_Q(q)
 
-        if original_q:
-            for q in q_list:
-                if store_triu_as_line:
-                    update_triu_(original_q[i], Q)
-                else:
-                    copy_stochastic_list_(original_q[i], Q)
-
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
     """Anneal preconditioner update probability during beginning of training.
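This last hunk fixes a genuine write-back bug. In 0.18.2 the `if original_q:` block ran after the loops and reused the stale loop variables `i` and `Q` (note the unused `q` in `for q in q_list:`), so only the last updated preconditioner was ever copied back into `original_q`. The rewrite selects `update_fn` once up front and calls it inside the main loop, once per tensor. A standalone reproduction of that loop-variable leak, with illustrative values only:

q_list = [[1], [2], [3]]
original_q = [[0], [0], [0]]

for i, Q in enumerate(q_list):
    pass  # after the loop, i == 2 and Q is q_list[-1]

for q in q_list:        # old write-back pattern: `q` is never used in the body
    original_q[i] = Q   # stale i and Q: the same slot is written three times
print(original_q)       # [[0], [0], [3]] — the first two were never updated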
|