heavyball 0.18.0.tar.gz → 0.18.2.tar.gz
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- {heavyball-0.18.0 → heavyball-0.18.2}/PKG-INFO +1 -1
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/cached_delayed_psgd_kron.py +11 -11
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/cached_psgd_kron.py +11 -15
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/delayed_psgd.py +5 -4
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/p_adam.py +10 -9
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/psgd_kron.py +8 -9
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/pure_psgd.py +11 -11
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/utils.py +7 -6
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.0 → heavyball-0.18.2}/setup.py +1 -1
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_stochastic_updates.py +4 -3
- {heavyball-0.18.0 → heavyball-0.18.2}/LICENSE +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/README.md +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/__init__.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/setup.cfg +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_closure.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_foreach.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_memory.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_merge.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_no_grad.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_psgd.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.2}/test/test_soap.py +0 -0
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/cached_delayed_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
-from heavyball.utils import einsum_base
 
-from .utils import update_param_, warmup,
-
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -42,7 +41,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -55,14 +56,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-
     def _step(self, group):
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
@@ -115,6 +113,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
+            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+
             if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
@@ -126,7 +126,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                 else:
                     torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g,
+            set_(g, new)
         grad_list = self.clip_fn(grad_list)
 
         lr = -warmup(lr, group['step'], group['warmup_steps'])
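The same change recurs in every PSGD variant below: `precond_init_scale` and `precond_lr`, previously hardcoded (1.0 and 0.1) inside the defaults dict, are now exposed as "expert parameters" on the constructor and threaded into the param group. A minimal usage sketch, assuming the optimizer classes are re-exported from the top-level `heavyball` package as the test suite's use of `heavyball.__all__` suggests; model, data, and hyperparameter values here are illustrative only, not taken from the package:

```python
import torch
from torch import nn

import heavyball

# Illustrative model; any module with 2-D weights exercises the Kron preconditioner.
model = nn.Linear(128, 128, bias=False)

# In 0.18.0 these two values were fixed inside the optimizer; in 0.18.2 they can be
# overridden, with defaults matching the old hardcoded behaviour.
opt = heavyball.ForeachCachedDelayedPSGDKron(model.parameters(), lr=1e-3,
                                             precond_init_scale=1.0,  # default
                                             precond_lr=0.1)          # default

x = torch.randn(32, 128)
loss = (model(x) - x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```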
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/cached_psgd_kron.py

@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
-from heavyball.utils import einsum_base
 
-from .utils import update_param_, warmup,
-
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -40,7 +39,9 @@ class ForeachCachedPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -53,15 +54,11 @@ class ForeachCachedPSGDKron(PSGDBase):
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line,
-                        q_dtype=q_dtype)
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-
     def _step(self, group):
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
@@ -114,19 +111,18 @@ class ForeachCachedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
-
             if self.should_update(group):
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
                 q32 = [promote(q_) for q_ in q]
-                self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+                               store_triu_as_line)
                 for c_, q_ in zip(cached_q, q):
                     if q_.ndim == 2:
                         torch.matmul(q_.T.conj(), q_, out=c_)
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-            set_(g,
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
 
         grad_list = self.clip_fn(grad_list)
 
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/delayed_psgd.py

@@ -39,7 +39,9 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -52,9 +54,8 @@ class ForeachDelayedPSGD(PSGDBase):
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale,
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/p_adam.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import triu_to_line, line_to_triu, identity
 
-from .utils import
-
+from heavyball.utils import triu_to_line, line_to_triu, identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
+    split_p_and_g_in_group, promote
 
 
 class ForeachPaLMPAdam(PSGDBase):
@@ -39,7 +39,9 @@ class ForeachPaLMPAdam(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
                  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool = True
+                 stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -52,11 +54,10 @@ class ForeachPaLMPAdam(PSGDBase):
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
+                        beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/psgd_kron.py

@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+    split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,9 @@ class ForeachPSGDKron(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
+                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -52,14 +54,11 @@ class ForeachPSGDKron(PSGDBase):
 
         defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
-
     def _step(self, group):
         momentum_into_precond_update = group.get("momentum_into_precond_update", True)
         precond_init_scale = group['precond_init_scale']
@@ -104,7 +103,7 @@ class ForeachPSGDKron(PSGDBase):
 
             if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/pure_psgd.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import copy_stochastic_list_, identity
 
-from .utils import
-
+from heavyball.utils import identity
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
+    line_to_triu, triu_to_line, promote
 
 
 class ForeachPurePSGD(PSGDBase):
@@ -37,8 +37,10 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+                 q_dtype='float32', stochastic_schedule: bool = True,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -49,11 +51,9 @@ class ForeachPurePSGD(PSGDBase):
 
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=
-
-
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -95,7 +95,7 @@ class ForeachPurePSGD(PSGDBase):
 
             if self.should_update(group):
                 q32 = [promote(q_) for q_ in q]
-                self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
+                self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)
{heavyball-0.18.0 → heavyball-0.18.2}/heavyball/utils.py

@@ -817,14 +817,15 @@ class PSGDBase(StatefulOptimizer):
 
     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
                   store_triu_as_line=False):
-        if self.should_update(group, self.balance_probability, 'balance_prob'):
-            for g, q in zip(grad_list, q_list):
-                if g.dim() > 1:
-                    psgd_balance_Q(q)
-
         for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
             psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-
+
+        for g, q in zip(grad_list, q_list):
+            if g.dim() > 1:
+                psgd_balance_Q(q)
+
+        if original_q:
+            for q in q_list:
                 if store_triu_as_line:
                     update_triu_(original_q[i], Q)
                 else:
{heavyball-0.18.0 → heavyball-0.18.2}/test/test_stochastic_updates.py

@@ -16,8 +16,8 @@ def get_memory():
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
-@pytest.mark.parametrize("size,depth", [(128,
-def test_foreach(opt, size, depth: int, iterations: int =
+@pytest.mark.parametrize("size,depth", [(128, 1)])
+def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
     set_torch()
 
     opt = getattr(heavyball, opt)
@@ -28,12 +28,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations
     losses = []
 
     for stochastic in [False, True]:
+        print('stochastic', stochastic)
         torch.manual_seed(0x2131290)
         peaks.append([])
         losses.append([])
 
         for i in range(outer_iterations):
-            model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+            model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
             o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic)
 
             for _ in range(iterations):
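Condensed, the updated test stacks bias-free `nn.Linear` layers (so every trainable tensor is 2-D, the case the Kron preconditioners target) and runs each optimizer once with `stochastic_schedule=False` and once with `True`. A rough standalone sketch of that comparison, assuming a CUDA device and the top-level exports noted above; the loss, batch size, and step count here are placeholders rather than the test's own values:

```python
import torch
from torch import nn

import heavyball

size, depth = 128, 1
for stochastic in [False, True]:
    torch.manual_seed(0x2131290)
    # bias=False matches the updated test: it keeps every parameter a matrix.
    model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
    opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                    stochastic_schedule=stochastic)

    for _ in range(32):  # the real test runs 8192 iterations per outer loop
        x = torch.randn(16, size, device='cuda')
        loss = (model(x) - x).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
    print('stochastic', stochastic, 'loss', loss.item())
```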