heavyball 0.17.3.tar.gz → 0.18.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.17.3 → heavyball-0.18.1}/PKG-INFO +1 -1
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_delayed_psgd_kron.py +12 -27
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_psgd_kron.py +12 -27
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/delayed_psgd.py +8 -20
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/p_adam.py +17 -30
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/psgd_kron.py +10 -24
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/pure_psgd.py +14 -27
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/utils.py +26 -12
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/SOURCES.txt +2 -1
- {heavyball-0.17.3 → heavyball-0.18.1}/setup.py +1 -1
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_bf16_q.py +1 -1
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_foreach.py +1 -1
- heavyball-0.18.1/test/test_stochastic_updates.py +52 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/LICENSE +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/README.md +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/__init__.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/setup.cfg +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_closure.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_memory.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_merge.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_no_grad.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_psgd.py +0 -0
- {heavyball-0.17.3 → heavyball-0.18.1}/test/test_soap.py +0 -0
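The bulk of the diff below threads a new `stochastic_schedule` keyword through every PSGD-based optimizer and moves the `clip_fn` / `preconditioner_update_probability` handling into `PSGDBase`. As a rough usage sketch, assuming the optimizer classes are exported from the top-level `heavyball` package (as the tests' use of `heavyball.__all__` suggests) and with placeholder model and hyperparameters that are not taken from this release:

import torch
from torch import nn

import heavyball

# stochastic_schedule=True keeps the old randomized preconditioner-update schedule;
# False opts into the new deterministic cumulative-probability schedule.
model = nn.Linear(64, 64)
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3, stochastic_schedule=False)

for _ in range(8):
    loss = model(torch.randn(16, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()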
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_delayed_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional

 import torch
-from heavyball.utils import einsum_base

-from .utils import update_param_, warmup,
-
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, set_, einsum_base, promote


 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -42,7 +41,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -50,33 +51,17 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn

         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +113,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)

-            if
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.
-
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/cached_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional

 import torch
-from heavyball.utils import einsum_base

-from .utils import update_param_, warmup,
-
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote


 class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +39,9 @@ class ForeachCachedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -48,33 +49,17 @@ class ForeachCachedPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn

         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -128,11 +113,11 @@ class ForeachCachedPSGDKron(PSGDBase):

             new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)

-            if
+            if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.
-
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
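Besides `stochastic_schedule`, the constructors above now also accept `precond_init_scale` and `precond_lr` as explicit "expert parameters" rather than fixing them inside the defaults dict. A hedged sketch of overriding them (the values here are illustrative, and the class is assumed to be exported at package level):

from torch import nn

import heavyball

model = nn.Linear(32, 32)
# The defaults (precond_lr=0.1, precond_init_scale=1.0) match the previously hard-coded
# values, so omitting these keywords preserves the old behavior.
opt = heavyball.ForeachCachedPSGDKron(model.parameters(), lr=1e-3,
                                      precond_lr=0.05, precond_init_scale=1.0)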
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/delayed_psgd.py

@@ -39,7 +39,9 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,32 +49,19 @@ class ForeachDelayedPSGD(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn

         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-                        precond_init_scale=1.0, # precond init scale hardcoded to 1.0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale,
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

-        self._prob_step = 0

     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +103,9 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-            if
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line
-                self.balance([g], [q32])
+                self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, new)

         grad_list = self.clip_fn(grad_list)
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/p_adam.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """

 import torch
-from heavyball.utils import triu_to_line, line_to_triu

-from .utils import
-
+from heavyball.utils import triu_to_line, line_to_triu, identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
+    split_p_and_g_in_group, promote


 class ForeachPaLMPAdam(PSGDBase):
@@ -38,8 +38,10 @@ class ForeachPaLMPAdam(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True,
-
+                 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
+                 stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -47,32 +49,18 @@ class ForeachPaLMPAdam(PSGDBase):
         if betas[0] is not None:
             beta = betas[0]

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn =
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity

         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
+                        beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -91,8 +79,8 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=q_dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q

             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
@@ -106,11 +94,10 @@ class ForeachPaLMPAdam(PSGDBase):
         group["step"] += 1

         Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
-        if
+        if self.should_update(group):
             for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
                 q32 = [promote(qq_) for qq_ in q_]
-                self.
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))

         beta2 = 1 - group['step'] ** -group['beta2_scale']
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/psgd_kron.py

@@ -9,7 +9,7 @@ from typing import Optional
 import torch

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote


 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,9 @@ class ForeachPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32'
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -47,32 +49,17 @@ class ForeachPSGDKron(PSGDBase):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
             clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn

         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

     def _step(self, group):
-        # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
@@ -114,10 +101,9 @@ class ForeachPSGDKron(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig

-            if
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))

         grad_list = self.clip_fn(grad_list)
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/pure_psgd.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """

 import torch
-from heavyball.utils import copy_stochastic_list_

-from .utils import
-
+from heavyball.utils import identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, promote


 class ForeachPurePSGD(PSGDBase):
@@ -37,39 +37,27 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+                 q_dtype='float32', stochastic_schedule: bool = True, #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        if preconditioner_update_probability is None:
-            preconditioner_update_probability = precond_update_prob_schedule()
         if clip_fn is None:
-            clip_fn =
-        self.preconditioner_update_probability = preconditioner_update_probability
-        self.clip_fn = clip_fn
+            clip_fn = identity

         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
-        super().__init__(params, defaults, foreach)
-
-        self._prob_step = 0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

     def _step(self, group):
         # update preconditioners all together
-        update_prob = self.preconditioner_update_probability
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        do_update = self.rng.random() < update_prob
-        self._prob_step += 1
-
         precond_init_scale = group['precond_init_scale']
         max_size_triangular = group['max_size_triangular']
         min_ndim_triangular = group['min_ndim_triangular']
@@ -105,10 +93,9 @@ class ForeachPurePSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig

-            if
+            if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.
-                self.do_update([p], [g], [q32], precond_lr, [q_orig], store_triu_as_line=store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)

         grad_list = self.clip_fn(grad_list)
{heavyball-0.17.3 → heavyball-0.18.1}/heavyball/utils.py

@@ -793,21 +793,35 @@ def update_triu_(q_state, materialised):
 class PSGDBase(StatefulOptimizer):
     balance_probability: float = 0.01

-    def __init__(self, parameters, groups, foreach: bool
-
+    def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
+                 preconditioner_update_probability):
+        super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
         self.rng = random.Random(0x1923213)
         self._tiny = torch.finfo(torch.bfloat16).tiny
-
-
-        if
-
-
-
-
-
-
-
+        if clip_fn is None:
+            clip_fn = identity
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        self.clip_fn = clip_fn
+        self.preconditioner_update_probability = preconditioner_update_probability
+
+    def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
+        group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+        if prob is None:
+            prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
+        if group['stochastic_schedule']:
+            return self.rng.random() < prob
+        cumulative_prob = group.get(name, 0)
+        group[name] = cumulative_prob + prob
+        return int(group[name]) > int(cumulative_prob)
+
+    def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
+        if self.should_update(group, self.balance_probability, 'balance_prob'):
+            for g, q in zip(grad_list, q_list):
+                if g.dim() > 1:
+                    psgd_balance_Q(q)
+
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
         if original_q:
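The new `should_update` replaces the per-instance `_prob_step` counters the subclasses used to keep. With `stochastic_schedule=True` it preserves the old behavior (an independent Bernoulli draw per step); with `stochastic_schedule=False` it accumulates the scheduled probability and fires whenever the running total crosses the next integer, giving deterministic, roughly evenly spaced update steps with the same long-run frequency. A minimal standalone sketch of that cumulative trick (not the library code; the fixed `prob` stands in for `precond_update_prob_schedule()`):

def deterministic_should_update(state: dict, prob: float) -> bool:
    # Mirror PSGDBase.should_update with stochastic_schedule=False:
    # accumulate the per-step probability and trigger whenever the
    # integer part of the running sum increases.
    previous = state.get('cumulative_prob', 0.0)
    state['cumulative_prob'] = previous + prob
    return int(state['cumulative_prob']) > int(previous)

state = {}
fired = [step for step in range(1, 21) if deterministic_should_update(state, 0.25)]
print(fired)  # with prob=0.25 the update fires on every 4th step: [4, 8, 12, 16, 20]

The same mechanism now drives preconditioner balancing: `do_update` calls `should_update(group, self.balance_probability, 'balance_prob')` internally, which is why the explicit `self.balance([g], [q32])` call disappears from the delayed_psgd.py hunk above.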
{heavyball-0.17.3 → heavyball-0.18.1}/test/test_bf16_q.py

@@ -28,11 +28,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations:
     losses = []

     for q_dtype in ['float32', 'bfloat16']:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])

         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype)

{heavyball-0.17.3 → heavyball-0.18.1}/test/test_foreach.py

@@ -26,11 +26,11 @@ def test_foreach(opt, size, depth: int, iterations: int = 5, outer_iterations: i
     losses = []

     for foreach in [True, False]:
+        torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])

         for i in range(outer_iterations):
-            torch.manual_seed(0x2131290)
             clean()
             model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
             clean()
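Both test changes above move `torch.manual_seed` from the inner loop to once per configuration, so each configuration (every `q_dtype` or `foreach` value) starts from the same RNG state while successive outer iterations within it now see different model initializations. A minimal sketch of the new seeding placement (sizes are arbitrary):

import torch
from torch import nn

torch.manual_seed(0x2131290)  # once per configuration, not once per outer iteration
for i in range(3):  # outer_iterations
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # different init each iteration
    # ... build the optimizer, train, and record losses ...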
heavyball-0.18.1/test/test_stochastic_updates.py (new file)

@@ -0,0 +1,52 @@
+import heavyball
+import heavyball.utils
+import pytest
+import torch
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch, PSGDBase
+from torch import nn
+
+
+def get_memory():
+    clean()
+    torch.cuda.synchronize()
+    clean()
+    torch.cuda.synchronize()
+    return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(128, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
+    set_torch()
+
+    opt = getattr(heavyball, opt)
+    if not issubclass(opt, PSGDBase):
+        raise pytest.skip('Only PSGD is supported')
+
+    peaks = []
+    losses = []
+
+    for stochastic in [False, True]:
+        torch.manual_seed(0x2131290)
+        peaks.append([])
+        losses.append([])
+
+        for i in range(outer_iterations):
+            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
+
+            for _ in range(iterations):
+                loss = model(torch.randn((128, size)).cuda()).square().mean()
+                loss.backward()
+                o.step()
+                o.zero_grad()
+                losses[-1].append(loss.detach())
+
+            del model, o
+            clean()
+
+    stochastic = sum([l.item() for l in losses[1]])
+    deterministic = sum([l.item() for l in losses[0]])
+    print(f"{deterministic=}, {stochastic=}")
+    assert deterministic < stochastic
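The new test trains each PSGD-based optimizer for 8192 steps under both schedules and asserts that the deterministic schedule (stochastic_schedule=False, collected first) ends with a lower summed loss than the stochastic one. A hedged way to run just this comparison for a single optimizer, assuming a CUDA-capable machine since the test moves models to the GPU:

import pytest

# Node ids are parametrized by optimizer name, so -k narrows the run to one class.
pytest.main(["test/test_stochastic_updates.py", "-k", "ForeachPSGDKron", "-x"])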