heavyball 0.21.8__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +6 -5
- heavyball/cached_delayed_psgd_kron.py +6 -5
- heavyball/cached_psgd_kron.py +7 -5
- heavyball/delayed_psgd.py +14 -11
- heavyball/foreach_adamw.py +14 -7
- heavyball/foreach_adopt.py +11 -6
- heavyball/foreach_laprop.py +12 -6
- heavyball/foreach_sfadamw.py +10 -3
- heavyball/foreach_soap.py +10 -8
- heavyball/p_adam.py +11 -9
- heavyball/palm_foreach_sfadamw.py +11 -3
- heavyball/palm_foreach_soap.py +8 -9
- heavyball/precond_schedule_foreach_soap.py +10 -8
- heavyball/precond_schedule_palm_foreach_soap.py +9 -9
- heavyball/precond_schedule_sfpsoap.py +10 -5
- heavyball/psgd_kron.py +9 -6
- heavyball/pure_psgd.py +11 -7
- heavyball/schedule_free_palm_foreach_soap.py +13 -5
- heavyball/utils.py +171 -106
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/METADATA +2 -2
- heavyball-0.23.0.dist-info/RECORD +24 -0
- heavyball-0.21.8.dist-info/RECORD +0 -24
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/LICENSE +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/WHEEL +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/top_level.txt +0 -0
heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
-    precond_schedule, set_,
+    precond_schedule, set_, StatefulOptimizer
 
 
 class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
@@ -33,14 +33,15 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale, 'split': split
+                    'beta2_scale': beta2_scale, 'split': split, 'mars': mars, 'caution': caution,
+                    'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -52,7 +53,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
             step = state['step'] = state.get("step", -1) + 1
 
@@ -86,6 +87,8 @@
         denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
         for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
@@ -96,10 +99,7 @@
             # to the original space
             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
             exp_avg_projected = exp_avg_projected / d
-
+            precond = project(exp_avg_projected, state['Q'], True)
 
             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
-
-            # Why does this have to be rebiased here?
-            step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-            update_param_(p_list, denom, step_size, group["weight_decay"])
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
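Note on the `caution` flag threaded through the hunks above: `update_param_` now also receives the raw gradient so it can apply a cautious update. A minimal single-tensor sketch of that idea (mask out update components whose sign disagrees with the gradient, then rescale), assuming a plain PyTorch implementation; `cautious_update_` below is illustrative, not heavyball's `update_param_`:

    import torch

    def cautious_update_(param: torch.Tensor, update: torch.Tensor, grad: torch.Tensor, step_size: float):
        # Keep only the components of `update` that agree in sign with the current gradient.
        mask = (update * grad > 0).to(update.dtype)
        # Rescale so the mean step magnitude is preserved despite the masking.
        mask *= mask.numel() / mask.sum().clamp_(min=1)
        # `step_size` is already negative in the diff above (-lr scaled by the warmup factor).
        param.add_(update * mask, alpha=step_size)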
heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -3,11 +3,11 @@ import random
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule,
+    beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, copy_stochastic_list_, \
     promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -52,15 +52,20 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
-                 split: bool = False, foreach: bool = True
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
-                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split
+                    'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split,
+                    'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -87,7 +92,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         # adaptive gradient clipping
         adaptive_gradient_clipping_(p_list, grad, group["gradient_clip_val"], eps=group["eps"])
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
 
             if "z" not in state:
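The compiled helpers shared by these SOAP variants follow one pattern: promote low-precision state to float32, do the accumulation there, and write the result back to the original buffers (heavyball's `copy_stochastic_list_` does this with stochastic rounding). A simplified single-tensor sketch of that pattern, with a plain `copy_` standing in for the stochastic write-back; all names here are illustrative:

    import torch

    def promote(x: torch.Tensor) -> torch.Tensor:
        # Upcast half-precision state so the accumulation happens in float32.
        return x.float() if x.dtype in (torch.float16, torch.bfloat16) else x

    @torch.compile(fullgraph=True, dynamic=False)
    def exp_avg_sq_step(exp_avg_sq: torch.Tensor, grad: torch.Tensor, beta2: float, eps: float) -> torch.Tensor:
        eas32, g32 = promote(exp_avg_sq), promote(grad)
        eas32.mul_(beta2).addcmul_(g32, g32, value=1 - beta2)
        if eas32 is not exp_avg_sq:  # state lives in bf16/fp16: write the fp32 result back
            exp_avg_sq.copy_(eas32)  # heavyball uses stochastic rounding here instead
        return eas32.sqrt().add_(eps)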
heavyball/psgd_kron.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+    line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
 class ForeachPSGDKron(PSGDBase):
@@ -40,7 +40,8 @@ class ForeachPSGDKron(PSGDBase):
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32',
+                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+                 #
                 # expert parameters
                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -57,7 +58,9 @@
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype,
+                        mars=mars, caution=caution, mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -77,7 +80,7 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -113,5 +116,5 @@
             q32 = [promote(q_) for q_ in q]
             self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                            store_triu_as_line)
-            g = psgd_precond_grad(
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
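The new options are plain constructor keywords. A usage sketch, assuming `ForeachPSGDKron` is re-exported from the package root (as the `__init__.py` change suggests); the model and hyper-parameter values are placeholders:

    import torch
    from heavyball import ForeachPSGDKron

    model = torch.nn.Linear(128, 128)
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3,
                          mars=True,          # enable the MARS gradient correction
                          caution=True,       # mask update components that oppose the gradient
                          mars_gamma=0.0025)  # default scale used throughout 0.23.0

    loss = model(torch.randn(4, 128)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()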
heavyball/pure_psgd.py
CHANGED
@@ -5,9 +5,9 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import identity
-
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, \
     line_to_triu, triu_to_line, promote
 
 
@@ -38,7 +38,8 @@ class ForeachPurePSGD(PSGDBase):
                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True,
+                 q_dtype='float32', stochastic_schedule: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, #
                 # expert parameters
                 precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:
@@ -49,11 +50,14 @@
         if clip_fn is None:
             clip_fn = identity
 
+        assert not mars, "MARS is not supported in this optimizer"
+
         defaults = dict(lr=lr, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -70,7 +74,7 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=0.0):
             state = self.state_(p)
 
             if 'Q' not in state:
@@ -97,5 +101,5 @@
             if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-            psgd_precond_grad(
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *q)
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -1,12 +1,13 @@
 import random
 
 import torch
+from heavyball.utils import mars_correction
 
 from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
-    beta_debias, schedule_free_, warmup, ScheduleFree,
+    beta_debias, schedule_free_, warmup, ScheduleFree, copy_stochastic_list_, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
     eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
     denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
@@ -44,15 +45,19 @@ class SFPaLMForeachSOAP(ScheduleFree):
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
-                 foreach: bool = True):
+                 foreach: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution is not implemented in ScheduleFree optimizers"
+
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
-                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split
+                    'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
@@ -61,6 +66,7 @@
         vals = []
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
+        mars = group['mars']
 
         step = group['step'] = group.get("step", 0) + 1
 
@@ -79,12 +85,14 @@
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
 
             if "z" not in state:
                 state["z"] = torch.clone(p).float()
                 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
+                if mars:
+                    state['mars_prev_grad'] = g.clone()
                 init_preconditioner(g, state, max_precond_dim, precondition_1d)
                 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
                 continue  # first step is skipped so that we never use the current gradients in the projection.
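schedule_free_palm_foreach_soap.py now imports `mars_correction` and keeps a per-parameter `mars_prev_grad` buffer. MARS replaces the raw gradient with a variance-reduced one before the usual moment updates; a minimal sketch of the commonly used formula, with `mars_correct_` as an illustrative stand-in for heavyball's `mars_correction`:

    import torch

    def mars_correct_(grad: torch.Tensor, prev_grad: torch.Tensor, beta1: float,
                      gamma: float = 0.0025) -> torch.Tensor:
        # c_t = g_t + gamma * beta1 / (1 - beta1) * (g_t - g_{t-1}), clipped to unit norm.
        c = grad + gamma * (beta1 / (1 - beta1)) * (grad - prev_grad)
        norm = c.norm()
        if norm > 1:
            c = c / norm
        prev_grad.copy_(grad)  # remember g_t for the next step (stored above as state['mars_prev_grad'])
        return c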