heavyball 0.21.8__tar.gz → 0.23.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.21.8 → heavyball-0.23.0}/PKG-INFO +2 -2
- {heavyball-0.21.8 → heavyball-0.23.0}/README.md +1 -1
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/__init__.py +6 -5
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/cached_delayed_psgd_kron.py +6 -5
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/cached_psgd_kron.py +7 -5
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/delayed_psgd.py +14 -11
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_adamw.py +14 -7
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_adopt.py +11 -6
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_laprop.py +12 -6
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_sfadamw.py +10 -3
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_soap.py +10 -8
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/p_adam.py +11 -9
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/palm_foreach_sfadamw.py +11 -3
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/palm_foreach_soap.py +8 -9
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/precond_schedule_foreach_soap.py +10 -8
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/precond_schedule_palm_foreach_soap.py +9 -9
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/precond_schedule_sfpsoap.py +10 -5
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/psgd_kron.py +9 -6
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/pure_psgd.py +11 -7
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/schedule_free_palm_foreach_soap.py +13 -5
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball/utils.py +171 -106
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball.egg-info/PKG-INFO +2 -2
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball.egg-info/SOURCES.txt +2 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/setup.py +1 -1
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_bf16_params.py +4 -14
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_bf16_q.py +0 -8
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_bf16_storage.py +0 -6
- heavyball-0.23.0/test/test_caution.py +41 -0
- heavyball-0.23.0/test/test_mars.py +45 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/LICENSE +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/setup.cfg +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_closure.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_ema.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_foreach.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_memory.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_merge.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_no_grad.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_psgd.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_soap.py +0 -0
- {heavyball-0.21.8 → heavyball-0.23.0}/test/test_stochastic_updates.py +0 -0
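
The headline change in 0.23.0 is a set of new optimizer options — `mars`, `mars_gamma`, and `caution` — threaded through every optimizer's constructor and parameter-group defaults, together with the new `test_caution.py` and `test_mars.py` tests. A minimal usage sketch follows; the toy model, training loop, and hyperparameter values are illustrative assumptions, while the keyword names and defaults are taken from the signatures in the hunks below.

```python
import torch
import heavyball

model = torch.nn.Linear(16, 16)

# mars and caution default to False, mars_gamma to 0.0025, in the new signatures;
# both features are opt-in.
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3,
                             mars=True, mars_gamma=0.0025, caution=True)

for _ in range(10):
    opt.zero_grad()
    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()
```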
{heavyball-0.21.8 → heavyball-0.23.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.8
+Version: 0.23.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler

@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
{heavyball-0.21.8 → heavyball-0.23.0}/README.md

@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/__init__.py

@@ -1,3 +1,4 @@
+from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 from .cached_psgd_kron import ForeachCachedPSGDKron
 from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW

@@ -14,7 +15,6 @@ from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
 from .psgd_kron import ForeachPSGDKron
 from .pure_psgd import ForeachPurePSGD
 from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
-from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
 
 PalmForEachSoap = PaLMForeachSOAP
 

@@ -39,7 +39,8 @@ CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
            'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
-           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
-
-           '
-           '
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron',
+           #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
+           'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
+           'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP']
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/cached_delayed_psgd_kron.py

@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 from heavyball.utils import min_dtype, precond_grad_cached_
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase,
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
     line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 

@@ -43,7 +43,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32',
+                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+                 #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:

@@ -61,7 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype)
+                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):

@@ -81,7 +82,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
            state = self.state_(p)
 
            if 'Q' not in state:

@@ -120,7 +121,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
            q_orig = Q_list.pop(0)
            ea = exp_avg_list.pop(0)
 
-           precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+           precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)
 
            if should_update:
                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/cached_psgd_kron.py

@@ -8,7 +8,7 @@ from typing import Optional
 
 import torch
 
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase,
+from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
     line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 

@@ -40,7 +40,8 @@ class ForeachCachedPSGDKron(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
                  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32',
+                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
+                 #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:

@@ -58,7 +59,7 @@ class ForeachCachedPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
                         split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype)
+                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):

@@ -78,7 +79,7 @@ class ForeachCachedPSGDKron(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
            state = self.state_(p)
 
            if 'Q' not in state:

@@ -128,4 +129,5 @@ class ForeachCachedPSGDKron(PSGDBase):
            else:
                torch.mul(q_.conj(), q_, out=c_)
 
-           precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
+           precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
+                                group['caution'], g)
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/delayed_psgd.py

@@ -5,17 +5,16 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import stochastic_lerp_, beta_debias
+from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+    triu_to_line, line_to_triu, promote, _compilable_update_
+
 
-# TODO: E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Error while creating guard:
-# E1123 00:51:55.423000 159394 site-packages/torch/_guards.py:283] [5/0] Name: "G['psgd_precond_grad'].__defaults__[0]"
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn):
-    new = psgd_precond_grad(
-
+def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
+    new = psgd_precond_grad(False, exprs, ea, *q)
+    _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
 
 
 class ForeachDelayedPSGD(PSGDBase):

@@ -46,7 +45,8 @@ class ForeachDelayedPSGD(PSGDBase):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:

@@ -63,7 +63,9 @@ class ForeachDelayedPSGD(PSGDBase):
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype,
+                        caution=caution, mars_gamma=mars_gamma, mars=mars)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):

@@ -83,7 +85,7 @@ class ForeachDelayedPSGD(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
            state = self.state_(p)
 
            if 'Q' not in state:

@@ -112,7 +114,8 @@ class ForeachDelayedPSGD(PSGDBase):
            q_orig = Q_list.pop(0)
            ea = exp_avg_list.pop(0)
            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-           _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn)
+           _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"][-1], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+                                          g)
            if should_update:
                q32 = [promote(q_) for q_ in q]
                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_adamw.py

@@ -1,18 +1,19 @@
 import torch
 import torch.optim
-
 from heavyball.utils import copy_stochastic_list_
+
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
     torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
     denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))
 
-    update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
+    update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
+                  grad=g32)
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

@@ -20,9 +21,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 
 class ForeachAdamW(StatefulOptimizer):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32'):
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):

@@ -48,9 +51,13 @@ class ForeachAdamW(StatefulOptimizer):
         y, grad, exp_avg_sq, exp_avg = zip(
             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
 
+        if group['mars']:
+            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+                          group['caution'])
 
         group['k'] = k + 1
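
Note how `caution` is threaded from the constructor through `_compilable_step_` into `update_param_(..., caution=caution, grad=g32)`. The masking itself lives in `heavyball/utils.py` (changed in this release but not included in this excerpt). A hedged sketch of what a cautious update mask typically does, assuming the standard "Cautious Optimizers" formulation of dropping update components whose sign disagrees with the current gradient:

```python
import torch

def cautious_mask_(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Illustrative only: keep update entries that agree in sign with the gradient."""
    mask = (update * grad > 0).to(update.dtype)   # 1 where update and gradient point the same way
    mask /= mask.mean().clamp(min=1e-3)           # rescale so the overall update magnitude is roughly preserved
    return update.mul_(mask)
```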
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_adopt.py

@@ -5,10 +5,10 @@ from heavyball.utils import copy_stochastic_list_
 from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-    update_param_(y, exp_avg, lr, decay)
+    update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
 
     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)

@@ -27,9 +27,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 class ForeachADOPT(StatefulOptimizer):
 
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32'):
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):

@@ -57,11 +59,14 @@ class ForeachADOPT(StatefulOptimizer):
 
         group['k'] = k + 1
 
+        if group['mars']:
+            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
         if k > 1:
             lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
             lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
             k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-            _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay)
+            _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
             return
 
         grad = [promote(g) for g in grad]
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_laprop.py

@@ -4,8 +4,8 @@ import torch.optim
 from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
 
     denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)

@@ -14,7 +14,7 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     torch._foreach_mul_(exp_avg32, beta1)
     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
 
-    update_param_(y, exp_avg32, lr, decay)
+    update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

@@ -23,9 +23,11 @@ def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
 class ForeachLaProp(StatefulOptimizer):
 
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
-                 foreach: bool = True, storage_dtype: str = 'float32'):
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype)
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):

@@ -52,10 +54,14 @@ class ForeachLaProp(StatefulOptimizer):
             *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
               for p in active_p])
 
+        if group['mars']:
+            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
         lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
         step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
 
-        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay)
+        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
+                          group['caution'])
 
         group['k'] = k + 1
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_sfadamw.py

@@ -5,7 +5,7 @@ from heavyball.utils import get_ckp1, copy_stochastic_list_
 from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, promote, _compilable_schedule_free_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
     old_debiased2 = beta_debias(beta2, step)
 

@@ -21,13 +21,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
+
 class ForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32'):
+                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025):
+
+        assert not caution, "Caution not implemented for SFAdamW"
 
         defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
                         weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        foreach=foreach, storage_dtype=storage_dtype)
+                        foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):

@@ -53,6 +57,9 @@ class ForeachSFAdamW(ScheduleFree):
         y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
                                        for p in active_p])
 
+        if group['mars']:
+            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
         ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
 
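
The `mars` path calls `self.mars_correct_list(group, y, grad, ...)` before the regular step; the list-level implementation sits in `heavyball/utils.py` and is not part of this excerpt. As a rough single-tensor sketch of the MARS-style correction it presumably applies (formula assumed from the MARS variance-reduction recipe, not taken from the package's code):

```python
import torch

def mars_correct_(grad: torch.Tensor, prev_grad: torch.Tensor,
                  mars_gamma: float = 0.0025, beta1: float = 0.9) -> torch.Tensor:
    """Illustrative only: add a scaled gradient-difference term, then clip to unit norm."""
    c = grad + mars_gamma * (beta1 / (1 - beta1)) * (grad - prev_grad)
    norm = c.norm()
    if norm > 1:                 # keep the corrected gradient bounded
        c = c / norm
    prev_grad.copy_(grad)        # stash the raw gradient for the next call
    return c
```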
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/foreach_soap.py

@@ -1,7 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-
+    StatefulOptimizer, exp_avg_
 
 
 class ForeachSOAP(StatefulOptimizer):

@@ -26,11 +26,13 @@ class ForeachSOAP(StatefulOptimizer):
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 split: bool = False, foreach: bool = True):
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025):
         defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
-                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split}
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
 

@@ -41,7 +43,7 @@ class ForeachSOAP(StatefulOptimizer):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['betas'][0]):
             state = self.state_(p)
             step = state['step'] = state.get("step", -1) + 1
 

@@ -71,6 +73,8 @@ class ForeachSOAP(StatefulOptimizer):
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
         denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
+
         for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner

@@ -80,11 +84,9 @@ class ForeachSOAP(StatefulOptimizer):
             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
             # to the original space
             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-
+            precond = project(exp_avg_projected / d, state['Q'], True)
 
             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
                                   step > 0 and step % group['precondition_frequency'] == 0)
 
-
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-        update_param_(p_list, denom, step_size, group["weight_decay"])
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/p_adam.py

@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
-
+    promote
 
 
 class ForeachPaLMPAdam(PSGDBase):

@@ -39,7 +39,8 @@ class ForeachPaLMPAdam(PSGDBase):
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
                  beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
                  store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool = True,
+                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         if not 0.0 <= lr:

@@ -57,7 +58,8 @@ class ForeachPaLMPAdam(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
                         beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+                        store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype,
+                        mars=mars, caution=caution, mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):

@@ -75,7 +77,7 @@ class ForeachPaLMPAdam(PSGDBase):
 
         vals = []
 
-        for p, g in split_p_and_g_in_group(group, should_promote=False):
+        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=group['beta']):
            state = self.state_(p)
 
            if 'Q' not in state:

@@ -107,13 +109,13 @@ class ForeachPaLMPAdam(PSGDBase):
         lr = -warmup(lr, group['step'], group['warmup_steps'])
 
         for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
-
-
+            gc = g.clone() if group['caution'] else None
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *Q)
+            ea = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *Q)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
             torch.div(ea, g, out=g)
             """
            divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
            divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
             """
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
-
+            update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=gc)
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/palm_foreach_sfadamw.py

@@ -5,7 +5,7 @@ from .utils import warmup, ScheduleFree, exp_avg_sq_, beta_debias, get_ckp1, pro
     _compilable_schedule_free_, copy_stochastic_list_
 
 
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
 def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
     old_debiased2 = beta_debias(beta2, step)
 

@@ -24,12 +24,17 @@ def _compilable_step_(y, grad, exp_avg_sq, z, beta1, beta2, step, ckp1, eps, decay, lr):
 
 class PaLMForeachSFAdamW(ScheduleFree):
     def __init__(self, params, lr=0.0025, beta=0.9, betas=(None, None), eps=1e-8, weight_decay=0, warmup_steps=0, r=0.0,
-                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32'):
+                 weight_lr_power=2.0, beta2_scale: float = 0.8, foreach: bool = True, storage_dtype: str = 'float32',
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
+
+        assert not caution, "Caution not implemented for SFAdamW"
+
         defaults = dict(lr=lr, beta=beta, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
                         lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
-                        beta2_scale=beta2_scale, storage_dtype=storage_dtype)
+                        beta2_scale=beta2_scale, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma)
         super().__init__(params, defaults, foreach)
 
     def _step(self, group):

@@ -58,6 +63,9 @@ class PaLMForeachSFAdamW(ScheduleFree):
         y, grad, exp_avg_sq, z = zip(*[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['z']) #
                                        for p in active_p])
 
+        if group['mars']:
+            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
+
         lr = warmup(group['lr'], k + 1, group['warmup_steps'])
         ckp1, group['weight_sum'] = get_ckp1(lr, group['weight_lr_power'], group['weight_sum'], group['r'], k + 1)
 
{heavyball-0.21.8 → heavyball-0.23.0}/heavyball/palm_foreach_soap.py

@@ -1,8 +1,7 @@
 import torch
 
 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-
-
+    StatefulOptimizer, exp_avg_
 
 
 class PaLMForeachSOAP(StatefulOptimizer):

@@ -33,14 +32,15 @@ class PaLMForeachSOAP(StatefulOptimizer):
                  max_precond_dim: int = 2048, #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
+                 beta2_scale: float = 0.8, split: bool = False, foreach: bool = True, mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025):
         if betas[0] is not None:
             beta = betas[0]
         defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
                     "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
                     "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
                     "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'beta2_scale': beta2_scale,
-                    'split': split}
+                    'split': split, 'mars': mars, 'caution': caution, 'mars_gamma': mars_gamma}
         super().__init__(params, defaults, foreach)
         self._data_format = data_format
 

@@ -51,7 +51,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
         max_precond_dim = group['max_precond_dim']
         precondition_1d = group['precondition_1d']
 
-        for p, g in split_p_and_g_in_group(group):
+        for p, g in self.split_p_and_g_in_group(group, beta1=group['beta']):
             state = self.state_(p)
             step = state['step'] = state.get("step", -1) + 1
 

@@ -82,6 +82,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
         denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
+        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
         for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
             state = self.state_(p)

@@ -92,11 +93,9 @@ class PaLMForeachSOAP(StatefulOptimizer):
             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
             # to the original space
             # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
-
+            precond = project(exp_avg_projected / d, state['Q'], True)
 
             update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
                                   step > 0 and step % group['precondition_frequency'] == 0)
 
-
-        step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
-        update_param_(p_list, denom, step_size, group["weight_decay"])
+            update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])