heavyball 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl
- heavyball/cached_psgd_kron.py +3 -2
- heavyball/delayed_psgd.py +5 -3
- heavyball/foreach_adamw.py +3 -2
- heavyball/foreach_adopt.py +3 -2
- heavyball/foreach_laprop.py +3 -2
- heavyball/foreach_sfadamw.py +4 -4
- heavyball/foreach_soap.py +4 -3
- heavyball/p_adam.py +4 -3
- heavyball/palm_foreach_sfadamw.py +3 -2
- heavyball/palm_foreach_soap.py +3 -2
- heavyball/precond_schedule_foreach_soap.py +3 -2
- heavyball/precond_schedule_palm_foreach_soap.py +3 -2
- heavyball/precond_schedule_sfpsoap.py +3 -3
- heavyball/psgd_kron.py +5 -3
- heavyball/pure_psgd.py +3 -2
- heavyball/schedule_free_palm_foreach_soap.py +4 -3
- heavyball/utils.py +23 -5
- {heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/METADATA +4 -2
- heavyball-0.16.0.dist-info/RECORD +23 -0
- heavyball-0.15.1.dist-info/RECORD +0 -23
- {heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/LICENSE +0 -0
- {heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/WHEEL +0 -0
- {heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/top_level.txt +0 -0
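The functional change in 0.16.0 is one new keyword argument: every optimizer now accepts `foreach: bool = True` and forwards it to the `StatefulOptimizer` base class, which splits a param group into per-parameter "fake groups" when the flag is off (see the utils.py diff below), trading multi-tensor throughput for lower peak memory. A minimal usage sketch, assuming `ForeachPSGDKron` is importable from the package root as the README suggests:

    import torch
    from heavyball import ForeachPSGDKron  # assumed top-level export

    model = torch.nn.Linear(64, 64)
    # foreach=True (the default) keeps the fused multi-tensor path;
    # foreach=False steps one parameter at a time to cut peak memory.
    opt = ForeachPSGDKron(model.parameters(), lr=1e-3, foreach=False)

    loss = model(torch.randn(8, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()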
heavyball/cached_psgd_kron.py
CHANGED
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -61,7 +62,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
heavyball/delayed_psgd.py
CHANGED
@@ -38,7 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachDelayedPSGD(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -113,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
             if do_update:
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr)
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
                 self.balance([g], [q])
             set_(g, new)
 
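Note on the `do_update` change above (mirrored in psgd_kron.py below): the call now threads `[q_orig] if store_triu_as_line else None` through as an extra argument, which suggests the preconditioner update can write its result back into the compressed line-format buffer rather than only mutating the materialized triangular `q`.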
heavyball/foreach_adamw.py
CHANGED
@@ -5,10 +5,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
 
 
 class ForeachAdamW(StatefulOptimizer):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
heavyball/foreach_adopt.py
CHANGED
@@ -6,10 +6,11 @@ from .utils import warmup, beta_debias, update_param_, StatefulOptimizer
 
 class ForeachADOPT(StatefulOptimizer):
 
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
heavyball/foreach_laprop.py
CHANGED
@@ -6,10 +6,11 @@ from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer
 
 class ForeachLaProp(StatefulOptimizer):
 
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
+                 foreach: bool = True):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
heavyball/foreach_sfadamw.py
CHANGED
@@ -6,12 +6,12 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
 
 class ForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach=
+                 weight_lr_power=2.0, foreach: bool = True):
 
         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         foreach=foreach)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
@@ -48,7 +48,7 @@ class ForeachSFAdamW(ScheduleFree):
         torch._foreach_add_(grad, y, alpha=decay)
 
         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
-        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0],
-
+        group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['betas'][0], y, z,
+                                             grad, group['r'], k + 1)
 
         group['k'] = k + 1
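For context on the new `schedule_free_` call signature: schedule-free methods (Defazio et al., "The Road Less Scheduled", arXiv:2405.15682) keep a fast iterate `z` and evaluate gradients at an interpolated point `y`, with an averaging weight derived from `lr ** weight_lr_power`. The sketch below shows the reference schedule-free update for a helper with this exact signature; it is an illustration under that assumption, not heavyball's actual implementation:

    import torch

    def schedule_free_(lr, weight_lr_power, weight_sum, beta1, y, z, grad, r, step):
        weight = (step ** r) * (lr ** weight_lr_power)  # averaging weight for this step
        weight_sum += weight
        ckp1 = weight / weight_sum                      # interpolation coefficient c_{k+1}
        for yi, zi, gi in zip(y, z, grad):
            yi.lerp_(zi, ckp1)                          # pull the eval point toward z
            yi.add_(gi, alpha=lr * (beta1 * (1 - ckp1) - 1))
            zi.sub_(gi, alpha=lr)                       # plain gradient step on z
        return weight_sum                               # caller stores it back into group['weight_sum']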
heavyball/foreach_soap.py
CHANGED
@@ -26,12 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False):
+                 split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
 
     def _step(self, group):
@@ -59,7 +60,7 @@ class ForeachSOAP(StatefulOptimizer):
             vals.append((p, g, grad_projected, exp_avg, exp_avg_sq))
 
         if not vals:
-           return
+            return
 
         p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
         beta1, beta2 = group["betas"]
heavyball/p_adam.py
CHANGED
@@ -38,7 +38,8 @@ class ForeachPaLMPAdam(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
-                 store_triu_as_line: bool = True):
+                 store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -60,7 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
                         split=split, store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -90,7 +91,7 @@ class ForeachPaLMPAdam(PSGDBase):
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
-
+                                                 min_ndim_triangular, memory_save_mode, dtype=g.dtype)
                 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
heavyball/palm_foreach_sfadamw.py
CHANGED
@@ -6,13 +6,14 @@ from .utils import schedule_free_, warmup, ScheduleFree, exp_avg_sq_, beta_debias
 
 class PaLMForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8):
+                 weight_lr_power=2.0, beta2_scale: float = 0.8,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
                         beta2_scale=beta2_scale)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
     def _step(self, group):
         eps = group['eps']
heavyball/palm_foreach_soap.py
CHANGED
@@ -32,7 +32,8 @@ class PaLMForeachSOAP(StatefulOptimizer):
                  max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False):
+                 beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
 
     def _step(self, group):
heavyball/precond_schedule_foreach_soap.py
CHANGED
@@ -27,13 +27,14 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), split: bool = False):
+                 precond_scheduler=(1 / 3, 9), split: bool = False,
+                 foreach: bool = True):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
heavyball/precond_schedule_palm_foreach_soap.py
CHANGED
@@ -32,7 +32,8 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False):
+                 precond_scheduler=(1 / 3, 9), betas=(None, None), beta2_scale: float = 0.8, split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -40,7 +41,7 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'precond_scheduler': precond_scheduler,
                     'beta2_scale': beta2_scale, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
heavyball/precond_schedule_sfpsoap.py
CHANGED
@@ -41,7 +41,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
                  weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-                 betas=(None, None), split: bool = False):
+                 betas=(None, None), split: bool = False, foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -50,7 +50,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1, 'weight_sum': 0,
                     'gradient_clip_val': gradient_clip_val, 'precond_scheduler': precond_scheduler, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
@@ -59,7 +59,7 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1
 
         for p in group["params"]:
            if p.grad is None:
heavyball/psgd_kron.py
CHANGED
@@ -38,7 +38,8 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -60,7 +61,7 @@ class ForeachPSGDKron(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
@@ -114,7 +115,8 @@ class ForeachPSGDKron(PSGDBase):
 
             if do_update:
                 self.balance([g], [q])
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr)
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
heavyball/pure_psgd.py
CHANGED
@@ -36,7 +36,8 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
+                 foreach: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -56,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
                         store_triu_as_line=store_triu_as_line)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
 
         self._prob_step = 0
 
heavyball/schedule_free_palm_foreach_soap.py
CHANGED
@@ -33,7 +33,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False):
+                 weight_lr_power=2.0, gradient_clip_val: float = 0.1, betas=(None, None), split: bool = False,
+                 foreach: bool = True):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -42,7 +43,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'r': r,
                     'weight_lr_power': weight_lr_power, 'train_mode': True, 'step': -1,
                     'gradient_clip_val': gradient_clip_val, 'weight_sum': 0, 'split': split}
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, foreach)
         self._data_format = data_format
         self.rng = random.Random(0x120983109)
 
@@ -51,7 +52,7 @@ class SFPaLMForeachSOAP(ScheduleFree):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        step = group['step'] = group.get("step",
+        step = group['step'] = group.get("step", 0) + 1
 
         for p in group["params"]:
            if p.grad is None:
heavyball/utils.py
CHANGED
@@ -383,8 +383,25 @@ def project(grad, Q, back: bool):
 
 
 class StatefulOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, defaults, foreach: bool = True):
+        super().__init__(params, {**defaults, 'foreach': foreach})
+        self.fake_groups = {}
+
+    def key(self, param: torch.Tensor):
+        return (param.data_ptr(), tuple(param.shape))
+
+    def get_groups(self, group):
+        if group['foreach']:
+            return [group]
+
+        for p in group['params']:
+            if self.key(p) not in self.fake_groups:
+                self.fake_groups[self.key(p)] = {**group, 'params': [p]}
+
+        return [self.fake_groups[self.key(p)] for p in group['params']]
+
     def state_(self, arg: torch.Tensor):
-        return self.state[arg]
+        return self.state[self.key(arg)]
 
     def state_size(self) -> int:
         total_bytes = 0
@@ -409,8 +426,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()
         with torch.no_grad():
-            for group in self.param_groups:
-                self._step(group)
+            for top_group in self.param_groups:
+                for group in self.get_groups(top_group):
+                    self._step(group)
         return loss
 
 
@@ -754,8 +772,8 @@ def update_triu_(q_state, materialised):
 
 
 class PSGDBase(StatefulOptimizer):
-    def __init__(self, parameters, groups):
-        super().__init__(parameters, groups)
+    def __init__(self, parameters, groups, foreach: bool = True):
+        super().__init__(parameters, groups, foreach)
        self.rng = random.Random(0x1923213)
        self._tiny = torch.finfo(torch.bfloat16).tiny
 
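The `get_groups` addition above is the heart of this release: with `foreach=False`, each top-level param group is shattered into cached single-parameter groups, and `step` now routes every `_step` call through `get_groups`, so per-parameter state stays attached to a stable `(data_ptr, shape)` key across steps. A self-contained sketch of the same logic, mirroring the diff:

    import torch

    def key(param: torch.Tensor):
        return (param.data_ptr(), tuple(param.shape))

    def get_groups(fake_groups: dict, group: dict):
        if group['foreach']:
            return [group]  # unchanged behaviour: one fused multi-tensor group
        for p in group['params']:
            if key(p) not in fake_groups:
                fake_groups[key(p)] = {**group, 'params': [p]}  # one group per tensor
        return [fake_groups[key(p)] for p in group['params']]

    params = [torch.zeros(3), torch.zeros(5, 5)]
    cache = {}
    groups = get_groups(cache, {'params': params, 'lr': 1e-3, 'foreach': False})
    assert len(groups) == 2 and groups[1]['params'] == [params[1]]
    # the cache makes repeated calls return the same group objects
    assert get_groups(cache, {'params': params, 'foreach': False})[0] is groups[0]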
{heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.15.1
+Version: 0.16.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -39,12 +39,14 @@ recommended experimental optimizer is `ForeachPSGDKron`.
 
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
-* **Foreach**: Fast multi-tensor application
+* **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
 * **PaLM Beta2**: Fast initial
   convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
 * **ScheduleFree**: No learning rate schedule, but better convergence
 * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
   better step-per-second in late convergence (explained below)
+* **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
+  bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
 
 ## Getting started
 
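The new README bullet is the user-facing side of `store_triu_as_line`. A sketch of the idea, assuming each square Kronecker factor `Q` is stored as the flat vector of its upper-triangular entries (the real `triu_to_line`/`line_to_triu` helpers in heavyball/utils.py handle lists of factors, but the trade-off is the same):

    import torch

    def triu_to_line(q: torch.Tensor) -> torch.Tensor:
        # n*(n+1)/2 floats instead of n*n: less memory, but every use pays
        # the bandwidth of rebuilding the square matrix.
        idx = torch.triu_indices(q.shape[0], q.shape[1])
        return q[idx[0], idx[1]]

    def line_to_triu(line: torch.Tensor, n: int) -> torch.Tensor:
        q = torch.zeros(n, n, dtype=line.dtype, device=line.device)
        idx = torch.triu_indices(n, n)
        q[idx[0], idx[1]] = line
        return q

    q = torch.triu(torch.randn(4, 4))
    assert torch.equal(line_to_triu(triu_to_line(q), 4), q)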
heavyball-0.16.0.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
+heavyball/cached_psgd_kron.py,sha256=vJuy639G-_ZLSRX3goSFMXALv-ucYjrxaEtpj0IHo-M,6802
+heavyball/delayed_psgd.py,sha256=sbwgAed5gmQpHNTPvuE7Si-gB-s0NVvN4d-4rNUJj4c,5893
+heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
+heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
+heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
+heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
+heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
+heavyball/p_adam.py,sha256=aCu4Qn0eHJETHuCGrfNKp2aygKk2ZoNQyxut3Vcqmoc,6112
+heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
+heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
+heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
+heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
+heavyball/psgd_kron.py,sha256=iWTAViuzxTodtQGZnkLsEXrLG8tNU-BQB3KkTYAVcX4,5874
+heavyball/pure_psgd.py,sha256=EuCPNM8TX13cOop-mvvBFh6Uo1UjD1vsE053hvil92Q,5136
+heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
+heavyball/utils.py,sha256=z6taEvpgszKTrscqgowKYqb0xIVpBDVDBNGgvTE4Pb8,28484
+heavyball-0.16.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.16.0.dist-info/METADATA,sha256=yjpldOTN2rXN2-KG7R9ytuyBfmSCDpznZeRuziANChE,11941
+heavyball-0.16.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.16.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.16.0.dist-info/RECORD,,
heavyball-0.15.1.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-heavyball/__init__.py,sha256=KbT0GMU0DKqZxq9laCrD7XgiqS9yxC1W52zhte5kjKs,2054
-heavyball/cached_psgd_kron.py,sha256=mXDtxq2WJST_aUJhrLr_xCCXSFaDvD5gCTSEveBUtac,6754
-heavyball/delayed_psgd.py,sha256=dN3NW1jmjxmUkgqxPwUVrqLY8nnBOFp4TVtJ_BhPDR4,5814
-heavyball/foreach_adamw.py,sha256=NSzoIgNm7eavzbJgkAF0k7TUEnWAgOpt9-4juIFoaSA,1729
-heavyball/foreach_adopt.py,sha256=WA07m5jocLfb1GPU8s6mJ2PteS-03ronkKm-VJrAm5I,1863
-heavyball/foreach_laprop.py,sha256=mE2NDGX9XgvRhsewcWnk_-FulZPqGA65ejYF_9-A1Xk,1768
-heavyball/foreach_sfadamw.py,sha256=ussHfPd99u3RTfMrCuu5oIbwNFLXK19wO1Fbz3JShlc,2097
-heavyball/foreach_soap.py,sha256=WWvssYKg607uoEJHftp8ag8mtKSKSeHrT0QTgqBucVg,4587
-heavyball/p_adam.py,sha256=ms7BoMHu3jKGsuztUeECrsXufGAwBpqGsxgZ5LBXLQg,6073
-heavyball/palm_foreach_sfadamw.py,sha256=wjUb_fNZNUmzWXyKvwB0unP9lvNMmaYSQo5YoeS5cj0,2200
-heavyball/palm_foreach_soap.py,sha256=2Sb4hUHQeexJcCgjHeQM_ENkZ6lG1DVxW72ryrvR6iY,5890
-heavyball/precond_schedule_foreach_soap.py,sha256=bHsDyh-UvHpHjumjqqy0PePoR1ZMsJV6o5wWvpLAA04,4815
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=myLTJNQKLtZ3Xi3MVTB-RYtx_XeMRJw5CIMJW75ndUY,6163
-heavyball/precond_schedule_sfpsoap.py,sha256=xeNWetBzBEYqfOSzl98aAVJsHk43QkrUUhHH_YD_mS4,6740
-heavyball/psgd_kron.py,sha256=rMG5UPEgyfQs_n1MHSEicekVDpbbIzinlL8akEyY918,5795
-heavyball/pure_psgd.py,sha256=LLVJhUAb04hgAmT3BTz_faswwQEQUkLhm_VwGQmbBUo,5088
-heavyball/schedule_free_palm_foreach_soap.py,sha256=w0P7lMmoijTpL9V7NwOHcNBFJQ7S1TS9aCiwPhY2yVw,6319
-heavyball/utils.py,sha256=PWmwjZPL4oxMjK79a5R1e7JHykphNi5GdpYqO_xmmFU,27829
-heavyball-0.15.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.15.1.dist-info/METADATA,sha256=0wImMJNYM-Zg0akh9hRf7X8ofVW6zlmpyDGgAkK5GFA,11667
-heavyball-0.15.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.15.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.15.1.dist-info/RECORD,,
{heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/LICENSE
File without changes
{heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/WHEEL
File without changes
{heavyball-0.15.1.dist-info → heavyball-0.16.0.dist-info}/top_level.txt
File without changes