heavyball 0.17.2__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +6 -19
- heavyball/cached_psgd_kron.py +4 -17
- heavyball/delayed_psgd.py +4 -17
- heavyball/p_adam.py +9 -23
- heavyball/psgd_kron.py +4 -17
- heavyball/pure_psgd.py +6 -19
- heavyball/utils.py +30 -13
- {heavyball-0.17.2.dist-info → heavyball-0.18.0.dist-info}/METADATA +1 -1
- {heavyball-0.17.2.dist-info → heavyball-0.18.0.dist-info}/RECORD +12 -12
- {heavyball-0.17.2.dist-info → heavyball-0.18.0.dist-info}/LICENSE +0 -0
- {heavyball-0.17.2.dist-info → heavyball-0.18.0.dist-info}/WHEEL +0 -0
- {heavyball-0.17.2.dist-info → heavyball-0.18.0.dist-info}/top_level.txt +0 -0
@@ -42,7 +42,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
42
42
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
43
43
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
44
44
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
45
|
-
foreach: bool = True, q_dtype='float32'):
|
45
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
|
46
46
|
if not 0.0 <= lr:
|
47
47
|
raise ValueError(f"Invalid learning rate: {lr}")
|
48
48
|
if not 0.0 <= beta < 1.0:
|
@@ -50,12 +50,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
50
50
|
if not 0.0 <= weight_decay:
|
51
51
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
52
52
|
|
53
|
-
if preconditioner_update_probability is None:
|
54
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
55
53
|
if clip_fn is None:
|
56
54
|
clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
|
57
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
58
|
-
self.clip_fn = clip_fn
|
59
55
|
|
60
56
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
61
57
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -63,20 +59,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
63
59
|
# precond lr hardcoded to 0.1
|
64
60
|
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
65
61
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
66
|
-
store_triu_as_line=store_triu_as_line,
|
67
|
-
|
68
|
-
super().__init__(params, defaults, foreach)
|
62
|
+
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
63
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
69
64
|
|
70
|
-
self._prob_step = 0
|
71
65
|
|
72
66
|
def _step(self, group):
|
73
|
-
# update preconditioners all together
|
74
|
-
update_prob = self.preconditioner_update_probability
|
75
|
-
if callable(update_prob):
|
76
|
-
update_prob = update_prob(self._prob_step)
|
77
|
-
do_update = self.rng.random() < update_prob
|
78
|
-
self._prob_step += 1
|
79
|
-
|
80
67
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
81
68
|
precond_init_scale = group['precond_init_scale']
|
82
69
|
max_size_triangular = group['max_size_triangular']
|
@@ -128,11 +115,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
128
115
|
q_orig = Q_list.pop(0)
|
129
116
|
ea = exp_avg_list.pop(0)
|
130
117
|
|
131
|
-
if
|
118
|
+
if self.should_update(group):
|
132
119
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
133
120
|
q32 = [promote(q_) for q_ in q]
|
134
|
-
self.
|
135
|
-
|
121
|
+
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
122
|
+
store_triu_as_line)
|
136
123
|
for c_, q_ in zip(cached_q, q):
|
137
124
|
if q_.ndim == 2:
|
138
125
|
torch.matmul(q_.T.conj(), q_, out=c_)
|
heavyball/cached_psgd_kron.py
CHANGED
@@ -40,7 +40,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
40
40
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
41
41
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
42
42
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
43
|
-
foreach: bool = True, q_dtype='float32'):
|
43
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
|
44
44
|
if not 0.0 <= lr:
|
45
45
|
raise ValueError(f"Invalid learning rate: {lr}")
|
46
46
|
if not 0.0 <= beta < 1.0:
|
@@ -48,12 +48,8 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
48
48
|
if not 0.0 <= weight_decay:
|
49
49
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
50
50
|
|
51
|
-
if preconditioner_update_probability is None:
|
52
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
53
51
|
if clip_fn is None:
|
54
52
|
clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
|
55
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
56
|
-
self.clip_fn = clip_fn
|
57
53
|
|
58
54
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
59
55
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -63,18 +59,10 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
63
59
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
64
60
|
store_triu_as_line=store_triu_as_line,
|
65
61
|
q_dtype=q_dtype)
|
66
|
-
super().__init__(params, defaults, foreach)
|
62
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
67
63
|
|
68
|
-
self._prob_step = 0
|
69
64
|
|
70
65
|
def _step(self, group):
|
71
|
-
# update preconditioners all together
|
72
|
-
update_prob = self.preconditioner_update_probability
|
73
|
-
if callable(update_prob):
|
74
|
-
update_prob = update_prob(self._prob_step)
|
75
|
-
do_update = self.rng.random() < update_prob
|
76
|
-
self._prob_step += 1
|
77
|
-
|
78
66
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
79
67
|
precond_init_scale = group['precond_init_scale']
|
80
68
|
max_size_triangular = group['max_size_triangular']
|
@@ -128,11 +116,10 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
128
116
|
|
129
117
|
new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
|
130
118
|
|
131
|
-
if
|
119
|
+
if self.should_update(group):
|
132
120
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
133
121
|
q32 = [promote(q_) for q_ in q]
|
134
|
-
self.
|
135
|
-
self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
|
122
|
+
self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
136
123
|
for c_, q_ in zip(cached_q, q):
|
137
124
|
if q_.ndim == 2:
|
138
125
|
torch.matmul(q_.T.conj(), q_, out=c_)
|
heavyball/delayed_psgd.py
CHANGED
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
39
39
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
40
40
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
41
41
|
split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
|
42
|
-
foreach: bool = True, q_dtype='float32'):
|
42
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
|
43
43
|
if not 0.0 <= lr:
|
44
44
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
45
|
if not 0.0 <= beta < 1.0:
|
@@ -47,12 +47,8 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
47
47
|
if not 0.0 <= weight_decay:
|
48
48
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
49
49
|
|
50
|
-
if preconditioner_update_probability is None:
|
51
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
52
50
|
if clip_fn is None:
|
53
51
|
clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
|
54
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
55
|
-
self.clip_fn = clip_fn
|
56
52
|
|
57
53
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
58
54
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -61,18 +57,10 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
61
57
|
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
62
58
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
63
59
|
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
64
|
-
super().__init__(params, defaults, foreach)
|
60
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
65
61
|
|
66
|
-
self._prob_step = 0
|
67
62
|
|
68
63
|
def _step(self, group):
|
69
|
-
# update preconditioners all together
|
70
|
-
update_prob = self.preconditioner_update_probability
|
71
|
-
if callable(update_prob):
|
72
|
-
update_prob = update_prob(self._prob_step)
|
73
|
-
do_update = self.rng.random() < update_prob
|
74
|
-
self._prob_step += 1
|
75
|
-
|
76
64
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
77
65
|
precond_init_scale = group['precond_init_scale']
|
78
66
|
max_size_triangular = group['max_size_triangular']
|
@@ -114,10 +102,9 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
114
102
|
ea = exp_avg_list.pop(0)
|
115
103
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
116
104
|
new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
|
117
|
-
if
|
105
|
+
if self.should_update(group):
|
118
106
|
q32 = [promote(q_) for q_ in q]
|
119
|
-
self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line
|
120
|
-
self.balance([g], [q32])
|
107
|
+
self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
121
108
|
set_(g, new)
|
122
109
|
|
123
110
|
grad_list = self.clip_fn(grad_list)
|
heavyball/p_adam.py
CHANGED
@@ -5,7 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import torch
|
8
|
-
from heavyball.utils import triu_to_line, line_to_triu
|
8
|
+
from heavyball.utils import triu_to_line, line_to_triu, identity
|
9
9
|
|
10
10
|
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
|
11
11
|
exp_avg_sq_, beta_debias, split_p_and_g_in_group, promote
|
@@ -38,8 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
38
38
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
39
39
|
momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
|
40
40
|
beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
|
41
|
-
store_triu_as_line: bool = True,
|
42
|
-
|
41
|
+
store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
|
42
|
+
stochastic_schedule: bool = True):
|
43
43
|
if not 0.0 <= lr:
|
44
44
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
45
|
if not 0.0 <= weight_decay:
|
@@ -47,12 +47,8 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
47
47
|
if betas[0] is not None:
|
48
48
|
beta = betas[0]
|
49
49
|
|
50
|
-
if preconditioner_update_probability is None:
|
51
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
52
50
|
if clip_fn is None:
|
53
|
-
clip_fn =
|
54
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
55
|
-
self.clip_fn = clip_fn
|
51
|
+
clip_fn = identity
|
56
52
|
|
57
53
|
defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
58
54
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -61,18 +57,9 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
61
57
|
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
62
58
|
step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
|
63
59
|
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
64
|
-
super().__init__(params, defaults, foreach)
|
65
|
-
|
66
|
-
self._prob_step = 0
|
60
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
67
61
|
|
68
62
|
def _step(self, group):
|
69
|
-
# update preconditioners all together
|
70
|
-
update_prob = self.preconditioner_update_probability
|
71
|
-
if callable(update_prob):
|
72
|
-
update_prob = update_prob(self._prob_step)
|
73
|
-
do_update = self.rng.random() < update_prob
|
74
|
-
self._prob_step += 1
|
75
|
-
|
76
63
|
precond_init_scale = group['precond_init_scale']
|
77
64
|
max_size_triangular = group['max_size_triangular']
|
78
65
|
min_ndim_triangular = group['min_ndim_triangular']
|
@@ -91,8 +78,8 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
91
78
|
if 'Q' not in state:
|
92
79
|
state['exp_avg'] = torch.zeros_like(g)
|
93
80
|
state['exp_avg_sq'] = torch.zeros_like(g)
|
94
|
-
Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
|
95
|
-
|
81
|
+
Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
|
82
|
+
memory_save_mode, dtype=q_dtype)
|
96
83
|
state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
|
97
84
|
|
98
85
|
vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
|
@@ -106,11 +93,10 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
106
93
|
group["step"] += 1
|
107
94
|
|
108
95
|
Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
|
109
|
-
if
|
96
|
+
if self.should_update(group):
|
110
97
|
for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
|
111
98
|
q32 = [promote(qq_) for qq_ in q_]
|
112
|
-
self.
|
113
|
-
self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
|
99
|
+
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
114
100
|
torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
|
115
101
|
|
116
102
|
beta2 = 1 - group['step'] ** -group['beta2_scale']
|
heavyball/psgd_kron.py
CHANGED
@@ -39,7 +39,7 @@ class ForeachPSGDKron(PSGDBase):
|
|
39
39
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
40
40
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
41
41
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
42
|
-
foreach: bool = True, q_dtype='float32'):
|
42
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
|
43
43
|
if not 0.0 <= lr:
|
44
44
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
45
|
if not 0.0 <= beta < 1.0:
|
@@ -47,12 +47,8 @@ class ForeachPSGDKron(PSGDBase):
|
|
47
47
|
if not 0.0 <= weight_decay:
|
48
48
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
49
49
|
|
50
|
-
if preconditioner_update_probability is None:
|
51
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
52
50
|
if clip_fn is None:
|
53
51
|
clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
|
54
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
55
|
-
self.clip_fn = clip_fn
|
56
52
|
|
57
53
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
58
54
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -61,18 +57,10 @@ class ForeachPSGDKron(PSGDBase):
|
|
61
57
|
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
62
58
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
63
59
|
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
64
|
-
super().__init__(params, defaults, foreach)
|
60
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
65
61
|
|
66
|
-
self._prob_step = 0
|
67
62
|
|
68
63
|
def _step(self, group):
|
69
|
-
# update preconditioners all together
|
70
|
-
update_prob = self.preconditioner_update_probability
|
71
|
-
if callable(update_prob):
|
72
|
-
update_prob = update_prob(self._prob_step)
|
73
|
-
do_update = self.rng.random() < update_prob
|
74
|
-
self._prob_step += 1
|
75
|
-
|
76
64
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
77
65
|
precond_init_scale = group['precond_init_scale']
|
78
66
|
max_size_triangular = group['max_size_triangular']
|
@@ -114,10 +102,9 @@ class ForeachPSGDKron(PSGDBase):
|
|
114
102
|
ea = exp_avg_list.pop(0)
|
115
103
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
116
104
|
|
117
|
-
if
|
105
|
+
if self.should_update(group):
|
118
106
|
q32 = [promote(q_) for q_ in q]
|
119
|
-
self.
|
120
|
-
self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
|
107
|
+
self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
121
108
|
set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
|
122
109
|
|
123
110
|
grad_list = self.clip_fn(grad_list)
|
heavyball/pure_psgd.py
CHANGED
@@ -5,7 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import torch
|
8
|
-
from heavyball.utils import copy_stochastic_list_
|
8
|
+
from heavyball.utils import copy_stochastic_list_, identity
|
9
9
|
|
10
10
|
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
|
11
11
|
split_p_and_g_in_group, line_to_triu, triu_to_line, promote
|
@@ -38,18 +38,14 @@ class ForeachPurePSGD(PSGDBase):
|
|
38
38
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
39
39
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
40
40
|
split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
|
41
|
-
foreach: bool = True, q_dtype='float32'):
|
41
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True):
|
42
42
|
if not 0.0 <= lr:
|
43
43
|
raise ValueError(f"Invalid learning rate: {lr}")
|
44
44
|
if not 0.0 <= weight_decay:
|
45
45
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
46
46
|
|
47
|
-
if preconditioner_update_probability is None:
|
48
|
-
preconditioner_update_probability = precond_update_prob_schedule()
|
49
47
|
if clip_fn is None:
|
50
|
-
clip_fn =
|
51
|
-
self.preconditioner_update_probability = preconditioner_update_probability
|
52
|
-
self.clip_fn = clip_fn
|
48
|
+
clip_fn = identity
|
53
49
|
|
54
50
|
defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
55
51
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
@@ -58,18 +54,10 @@ class ForeachPurePSGD(PSGDBase):
|
|
58
54
|
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
59
55
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
60
56
|
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
61
|
-
super().__init__(params, defaults, foreach)
|
62
|
-
|
63
|
-
self._prob_step = 0
|
57
|
+
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
64
58
|
|
65
59
|
def _step(self, group):
|
66
60
|
# update preconditioners all together
|
67
|
-
update_prob = self.preconditioner_update_probability
|
68
|
-
if callable(update_prob):
|
69
|
-
update_prob = update_prob(self._prob_step)
|
70
|
-
do_update = self.rng.random() < update_prob
|
71
|
-
self._prob_step += 1
|
72
|
-
|
73
61
|
precond_init_scale = group['precond_init_scale']
|
74
62
|
max_size_triangular = group['max_size_triangular']
|
75
63
|
min_ndim_triangular = group['min_ndim_triangular']
|
@@ -105,10 +93,9 @@ class ForeachPurePSGD(PSGDBase):
|
|
105
93
|
q_orig = Q_list.pop(0)
|
106
94
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
107
95
|
|
108
|
-
if
|
96
|
+
if self.should_update(group):
|
109
97
|
q32 = [promote(q_) for q_ in q]
|
110
|
-
self.
|
111
|
-
self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
|
98
|
+
self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
112
99
|
psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
|
113
100
|
|
114
101
|
grad_list = self.clip_fn(grad_list)
|
heavyball/utils.py
CHANGED
@@ -668,7 +668,10 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
|
|
668
668
|
|
669
669
|
term2 += term1 # a + b
|
670
670
|
term1 *= 2 # 2a
|
671
|
-
term1
|
671
|
+
if term1.dtype == term2.dtype:
|
672
|
+
term1 -= term2 # 2a - (a + b) == a - b
|
673
|
+
else:
|
674
|
+
term1 = term1 - term2
|
672
675
|
|
673
676
|
term1 *= step
|
674
677
|
norm = term2.norm(float('inf'))
|
@@ -790,21 +793,35 @@ def update_triu_(q_state, materialised):
|
|
790
793
|
class PSGDBase(StatefulOptimizer):
|
791
794
|
balance_probability: float = 0.01
|
792
795
|
|
793
|
-
def __init__(self, parameters, groups, foreach: bool
|
794
|
-
|
796
|
+
def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
|
797
|
+
preconditioner_update_probability):
|
798
|
+
super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
|
795
799
|
self.rng = random.Random(0x1923213)
|
796
800
|
self._tiny = torch.finfo(torch.bfloat16).tiny
|
797
|
-
|
798
|
-
|
799
|
-
if
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
801
|
+
if clip_fn is None:
|
802
|
+
clip_fn = identity
|
803
|
+
if preconditioner_update_probability is None:
|
804
|
+
preconditioner_update_probability = precond_update_prob_schedule()
|
805
|
+
self.clip_fn = clip_fn
|
806
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
807
|
+
|
808
|
+
def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
|
809
|
+
group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
|
810
|
+
if prob is None:
|
811
|
+
prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
|
812
|
+
if group['stochastic_schedule']:
|
813
|
+
return self.rng.random() < prob
|
814
|
+
cumulative_prob = group.get(name, 0)
|
815
|
+
group[name] = cumulative_prob + prob
|
816
|
+
return int(group[name]) > int(cumulative_prob)
|
817
|
+
|
818
|
+
def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
|
807
819
|
store_triu_as_line=False):
|
820
|
+
if self.should_update(group, self.balance_probability, 'balance_prob'):
|
821
|
+
for g, q in zip(grad_list, q_list):
|
822
|
+
if g.dim() > 1:
|
823
|
+
psgd_balance_Q(q)
|
824
|
+
|
808
825
|
for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
|
809
826
|
psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
|
810
827
|
if original_q:
|
@@ -1,24 +1,24 @@
|
|
1
1
|
heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
|
2
|
-
heavyball/cached_delayed_psgd_kron.py,sha256=
|
3
|
-
heavyball/cached_psgd_kron.py,sha256=
|
4
|
-
heavyball/delayed_psgd.py,sha256=
|
2
|
+
heavyball/cached_delayed_psgd_kron.py,sha256=JKNgqgT59Aa9evpCG--mOZcYC0qBqLntNU7uQYYWURM,6487
|
3
|
+
heavyball/cached_psgd_kron.py,sha256=RCvXEiA-WwCERgndquaD-GU5L1wB4aeffkVDsbytaV4,6482
|
4
|
+
heavyball/delayed_psgd.py,sha256=f3nX-6N6BQtZ3uIdZc7w2uJSphTJb4V1ocWpY35VdSU,5525
|
5
5
|
heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
|
6
6
|
heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
|
7
7
|
heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
|
8
8
|
heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
|
9
9
|
heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
|
10
|
-
heavyball/p_adam.py,sha256=
|
10
|
+
heavyball/p_adam.py,sha256=lvPsG6tjV-NMY8lYTLVwmqTDz_BJg8SwbwQQrm55YlM,5849
|
11
11
|
heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
|
12
12
|
heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
|
13
13
|
heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
|
14
14
|
heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
|
15
15
|
heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
|
16
|
-
heavyball/psgd_kron.py,sha256=
|
17
|
-
heavyball/pure_psgd.py,sha256=
|
16
|
+
heavyball/psgd_kron.py,sha256=xBXO2rdzPtxcGjb5xDWPcFKixVoO9nwaaUUWckq1rBI,5466
|
17
|
+
heavyball/pure_psgd.py,sha256=uMIQsslIbXuJ_uHJ9_bQD23mPv1IG0CZD6eA6W_AJ6g,4901
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.
|
21
|
-
heavyball-0.
|
22
|
-
heavyball-0.
|
23
|
-
heavyball-0.
|
24
|
-
heavyball-0.
|
19
|
+
heavyball/utils.py,sha256=sojDNo94l-jAPDSnhEM5EI6D83SHG_hTnJOkdnDquSI,30492
|
20
|
+
heavyball-0.18.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.18.0.dist-info/METADATA,sha256=_IjP5WGcuqV-ryQyy7pW9b5BRHxIJtsMFTRWuRDDs3o,11810
|
22
|
+
heavyball-0.18.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.18.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.18.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|