heavyball 0.18.0__tar.gz → 0.18.1__tar.gz
This diff compares the contents of two publicly available package versions, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- {heavyball-0.18.0 → heavyball-0.18.1}/PKG-INFO +1 -1
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/cached_delayed_psgd_kron.py +8 -10
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/cached_psgd_kron.py +10 -12
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/delayed_psgd.py +5 -4
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/p_adam.py +10 -9
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/psgd_kron.py +8 -9
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/pure_psgd.py +11 -11
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.0 → heavyball-0.18.1}/setup.py +1 -1
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_stochastic_updates.py +1 -1
- {heavyball-0.18.0 → heavyball-0.18.1}/LICENSE +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/README.md +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/__init__.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball/utils.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/setup.cfg +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_closure.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_foreach.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_memory.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_merge.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_no_grad.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_psgd.py +0 -0
- {heavyball-0.18.0 → heavyball-0.18.1}/test/test_soap.py +0 -0
@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
7
7
|
from typing import Optional
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from heavyball.utils import einsum_base
|
11
10
|
|
12
|
-
from .utils import update_param_, warmup,
|
13
|
-
|
11
|
+
from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
|
12
|
+
line_to_triu, triu_to_line, set_, einsum_base, promote
|
14
13
|
|
15
14
|
|
16
15
|
class ForeachCachedDelayedPSGDKron(PSGDBase):
|
@@ -42,7 +41,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
42
41
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
43
42
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
44
43
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
45
|
-
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
|
44
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
|
45
|
+
# expert parameters
|
46
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
46
47
|
if not 0.0 <= lr:
|
47
48
|
raise ValueError(f"Invalid learning rate: {lr}")
|
48
49
|
if not 0.0 <= beta < 1.0:
|
@@ -55,14 +56,11 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
55
56
|
|
56
57
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
57
58
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
58
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
59
|
-
|
60
|
-
|
61
|
-
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
62
|
-
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
59
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
60
|
+
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
61
|
+
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
63
62
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
64
63
|
|
65
|
-
|
66
64
|
def _step(self, group):
|
67
65
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
68
66
|
precond_init_scale = group['precond_init_scale']
|
@@ -7,10 +7,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
7
7
|
from typing import Optional
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from heavyball.utils import einsum_base
|
11
10
|
|
12
|
-
from .utils import update_param_, warmup,
|
13
|
-
|
11
|
+
from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
|
12
|
+
split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
|
14
13
|
|
15
14
|
|
16
15
|
class ForeachCachedPSGDKron(PSGDBase):
|
@@ -40,7 +39,9 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
40
39
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
41
40
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
42
41
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
43
|
-
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
|
42
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
|
43
|
+
# expert parameters
|
44
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
44
45
|
if not 0.0 <= lr:
|
45
46
|
raise ValueError(f"Invalid learning rate: {lr}")
|
46
47
|
if not 0.0 <= beta < 1.0:
|
@@ -53,15 +54,11 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
53
54
|
|
54
55
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
55
56
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
56
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
57
|
-
|
58
|
-
|
59
|
-
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
60
|
-
store_triu_as_line=store_triu_as_line,
|
61
|
-
q_dtype=q_dtype)
|
57
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
58
|
+
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
59
|
+
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
62
60
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
63
61
|
|
64
|
-
|
65
62
|
def _step(self, group):
|
66
63
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
67
64
|
precond_init_scale = group['precond_init_scale']
|
@@ -119,7 +116,8 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
119
116
|
if self.should_update(group):
|
120
117
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
121
118
|
q32 = [promote(q_) for q_ in q]
|
122
|
-
self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
119
|
+
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
120
|
+
store_triu_as_line)
|
123
121
|
for c_, q_ in zip(cached_q, q):
|
124
122
|
if q_.ndim == 2:
|
125
123
|
torch.matmul(q_.T.conj(), q_, out=c_)
|
@@ -39,7 +39,9 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
39
39
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
40
40
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
41
41
|
split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
|
42
|
-
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
|
42
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
|
43
|
+
# expert parameters
|
44
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
43
45
|
if not 0.0 <= lr:
|
44
46
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
47
|
if not 0.0 <= beta < 1.0:
|
@@ -52,9 +54,8 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
52
54
|
|
53
55
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
54
56
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
55
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
56
|
-
|
57
|
-
precond_init_scale=1.0, # precond init scale hardcoded to 1.0
|
57
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
58
|
+
precond_init_scale=precond_init_scale,
|
58
59
|
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
59
60
|
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
60
61
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import torch
|
8
|
-
from heavyball.utils import triu_to_line, line_to_triu, identity
|
9
8
|
|
10
|
-
from .utils import
|
11
|
-
|
9
|
+
from heavyball.utils import triu_to_line, line_to_triu, identity
|
10
|
+
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
|
11
|
+
split_p_and_g_in_group, promote
|
12
12
|
|
13
13
|
|
14
14
|
class ForeachPaLMPAdam(PSGDBase):
|
@@ -39,7 +39,9 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
39
39
|
momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
|
40
40
|
beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
|
41
41
|
store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
|
42
|
-
stochastic_schedule: bool = True
|
42
|
+
stochastic_schedule: bool = True, #
|
43
|
+
# expert parameters
|
44
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
43
45
|
if not 0.0 <= lr:
|
44
46
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
47
|
if not 0.0 <= weight_decay:
|
@@ -52,11 +54,10 @@ class ForeachPaLMPAdam(PSGDBase):
|
|
52
54
|
|
53
55
|
defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
54
56
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
55
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
57
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
58
|
+
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
|
59
|
+
beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
|
60
|
+
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
60
61
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
61
62
|
|
62
63
|
def _step(self, group):
|
@@ -9,7 +9,7 @@ from typing import Optional
|
|
9
9
|
import torch
|
10
10
|
|
11
11
|
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
|
12
|
-
|
12
|
+
split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
|
13
13
|
|
14
14
|
|
15
15
|
class ForeachPSGDKron(PSGDBase):
|
@@ -39,7 +39,9 @@ class ForeachPSGDKron(PSGDBase):
|
|
39
39
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
40
40
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
41
41
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
42
|
-
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True
|
42
|
+
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
|
43
|
+
# expert parameters
|
44
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
43
45
|
if not 0.0 <= lr:
|
44
46
|
raise ValueError(f"Invalid learning rate: {lr}")
|
45
47
|
if not 0.0 <= beta < 1.0:
|
@@ -52,14 +54,11 @@ class ForeachPSGDKron(PSGDBase):
|
|
52
54
|
|
53
55
|
defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
54
56
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
55
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
56
|
-
|
57
|
-
|
58
|
-
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
59
|
-
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
57
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
58
|
+
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
59
|
+
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
60
60
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
61
61
|
|
62
|
-
|
63
62
|
def _step(self, group):
|
64
63
|
momentum_into_precond_update = group.get("momentum_into_precond_update", True)
|
65
64
|
precond_init_scale = group['precond_init_scale']
|
@@ -104,7 +103,7 @@ class ForeachPSGDKron(PSGDBase):
|
|
104
103
|
|
105
104
|
if self.should_update(group):
|
106
105
|
q32 = [promote(q_) for q_ in q]
|
107
|
-
self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
106
|
+
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
108
107
|
set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
|
109
108
|
|
110
109
|
grad_list = self.clip_fn(grad_list)
|
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import torch
|
8
|
-
from heavyball.utils import copy_stochastic_list_, identity
|
9
8
|
|
10
|
-
from .utils import
|
11
|
-
|
9
|
+
from heavyball.utils import identity
|
10
|
+
from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, split_p_and_g_in_group, \
|
11
|
+
line_to_triu, triu_to_line, promote
|
12
12
|
|
13
13
|
|
14
14
|
class ForeachPurePSGD(PSGDBase):
|
@@ -37,8 +37,10 @@ class ForeachPurePSGD(PSGDBase):
|
|
37
37
|
def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
|
38
38
|
max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
|
39
39
|
momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
|
40
|
-
split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
|
41
|
-
|
40
|
+
split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
|
41
|
+
q_dtype='float32', stochastic_schedule: bool = True, #
|
42
|
+
# expert parameters
|
43
|
+
precond_init_scale=1.0, precond_lr=0.1):
|
42
44
|
if not 0.0 <= lr:
|
43
45
|
raise ValueError(f"Invalid learning rate: {lr}")
|
44
46
|
if not 0.0 <= weight_decay:
|
@@ -49,11 +51,9 @@ class ForeachPurePSGD(PSGDBase):
|
|
49
51
|
|
50
52
|
defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
|
51
53
|
min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
|
52
|
-
momentum_into_precond_update=momentum_into_precond_update, precond_lr=
|
53
|
-
|
54
|
-
|
55
|
-
step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
|
56
|
-
store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
54
|
+
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
55
|
+
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
56
|
+
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
|
57
57
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
58
58
|
|
59
59
|
def _step(self, group):
|
@@ -95,7 +95,7 @@ class ForeachPurePSGD(PSGDBase):
|
|
95
95
|
|
96
96
|
if self.should_update(group):
|
97
97
|
q32 = [promote(q_) for q_ in q]
|
98
|
-
self.do_update(group,[p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
98
|
+
self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
|
99
99
|
psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
|
100
100
|
|
101
101
|
grad_list = self.clip_fn(grad_list)
|
@@ -17,7 +17,7 @@ def get_memory():
|
|
17
17
|
|
18
18
|
@pytest.mark.parametrize("opt", heavyball.__all__)
|
19
19
|
@pytest.mark.parametrize("size,depth", [(128, 2)])
|
20
|
-
def test_foreach(opt, size, depth: int, iterations: int =
|
20
|
+
def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3):
|
21
21
|
set_torch()
|
22
22
|
|
23
23
|
opt = getattr(heavyball, opt)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|