heavyball 1.1.0__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.1.0 → heavyball-1.1.2}/PKG-INFO +1 -1
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball/__init__.py +40 -35
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball/chainable.py +24 -25
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball/utils.py +108 -80
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.1.0 → heavyball-1.1.2}/setup.py +1 -1
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_channels_last.py +2 -1
- {heavyball-1.1.0 → heavyball-1.1.2}/LICENSE +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/README.md +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/setup.cfg +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_bf16_params.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_bf16_q.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_bf16_storage.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_caution.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_closure.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_ema.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_foreach.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_hook.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_mars.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_memory.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_merge.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_no_grad.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_psgd.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_soap.py +0 -0
- {heavyball-1.1.0 → heavyball-1.1.2}/test/test_stochastic_updates.py +0 -0
heavyball/__init__.py

@@ -10,9 +10,9 @@ class ForeachAdamW(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)


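The recurring change across these optimizer constructors replaces a hand-written `defaults = dict(...)` (truncated in this rendering) with a `locals()` capture. A minimal sketch of the pattern outside heavyball — the `Base` class below is hypothetical, standing in for `C.BaseOpt`:

```python
class Base:
    def __init__(self, params, defaults):
        self.params = params
        self.defaults = defaults


class MyOpt(Base):
    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
        # locals() at the top of __init__ is exactly the keyword arguments,
        # plus `self` and `params`, which are popped before use.
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        super().__init__(params, defaults)


opt = MyOpt([0, 1, 2], lr=1e-3)
assert opt.defaults == {"lr": 1e-3, "betas": (0.9, 0.99), "eps": 1e-8, "weight_decay": 0}
```

Note that the capture only equals the keyword arguments if it happens before any other local variables are bound, which is why the ForeachPSGDKron hunk further down moves `defaults = locals()` ahead of the `C.default(...)` resolution lines.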
@@ -25,9 +25,9 @@ class ForeachRMSprop(C.BaseOpt):
                  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)


@@ -36,10 +36,9 @@ class ForeachSFAdamW(C.ScheduleFree):
                  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
-                        beta2_scale=beta2_scale)
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
                          C.update_by_schedule_free)

@@ -53,9 +52,9 @@ class ForeachADOPT(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)


@@ -65,9 +64,9 @@ class ForeachMuon(C.BaseOpt):
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
                  nesterov: bool = True):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
                          C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)

@@ -77,12 +76,24 @@ class ForeachLaProp(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)


+class MuonLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+                         C.orthogonalize_update)
+
+
 class ForeachSOAP(C.BaseOpt):
     """
     ForeachSOAP
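The new `MuonLaProp` class chains `C.scale_by_laprop` with `C.orthogonalize_update`, i.e. a LaProp-style second-moment rescaling followed by orthogonalization of matrix-shaped updates. A hedged usage sketch — the toy model and loop below are illustrative, not from the package:

```python
import torch
from torch import nn

import heavyball

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
opt = heavyball.MuonLaProp(model.parameters(), lr=0.0025, betas=(0.9, 0.99), weight_decay=0.01)

for _ in range(10):
    x = torch.randn(32, 64)
    y = torch.randint(0, 10, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

The constructor keyword arguments used above (`lr`, `betas`, `weight_decay`) are taken directly from the signature added in the hunk.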
@@ -112,12 +123,10 @@ class ForeachSOAP(C.BaseOpt):
                  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

-        defaults =
-
-
-
-                    'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale}
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
         if use_precond_schedule:
             del defaults['precondition_frequency']
         else:
@@ -161,19 +170,15 @@ class ForeachPSGDKron(C.BaseOpt):
                  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
         delayed = C.default(delayed, self.delayed)
         cached = C.default(cached, self.cached)
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)

-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
-                        stochastic_schedule=stochastic_schedule)
-
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,  #
                          *(C.exp_avg,) * exp_avg_input,  #
                          functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
@@ -215,9 +220,9 @@ CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon

-__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+__all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
            "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop',
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
            "ForeachRMSprop", "ForeachMuon"]
heavyball/chainable.py

@@ -140,16 +140,14 @@ class SkipUpdate(ValueError):
 @zero_guard("exp_avg")
 @no_state
 def exp_avg(group, update, grad, param, exp_avg):
-    utils.
-    return exp_avg
+    return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-
-
-    return out
+    return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])


 @zero_guard("exp_avg", "exp_avg_sq")
@@ -162,22 +160,21 @@ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
-                      group['lr'], group['eps'], group['weight_decay'], group['caution'])
+    utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
+                      group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step']
-                         group['eps'])
+    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+    utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
                         group['step'], group['lr'], group['weight_decay'], group['caution'])
     raise SkipUpdate

@@ -205,7 +202,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
     raise SkipUpdate

-    utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+    utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
                        group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
     raise SkipUpdate

@@ -264,13 +261,13 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str


 @no_state_no_foreach
-def orthogonalize_update(group, update, grad, param):
+def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"):  # explore scale_mode="graft"
     if update.dim() == 1:
         return update
     original_shape = update.shape
     # doing it this way, as tmp and update are not guaranteed to share memory address or layout
     tmp = update.flatten(1, -1)
-    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
     return tmp.reshape(original_shape)

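`orthogonalize_update` flattens everything after the first dimension before orthogonalizing, so e.g. a conv kernel of shape `(out, in, kh, kw)` is treated as an `(out, in*kh*kw)` matrix and then restored; the new `scale_mode` argument is forwarded to `inplace_orthogonal_` (see the utils.py hunk further down). A shape-only sketch of that round trip in plain PyTorch, using QR as a stand-in for heavyball's Newton-Schulz step:

```python
import torch

update = torch.randn(8, 4, 3, 3)            # e.g. a conv weight update
tmp = update.flatten(1, -1)                 # -> (8, 36): one row per output channel
assert tmp.shape == (8, 36)

# Stand-in for the orthogonalization step: QR on the flattened matrix.
q = torch.linalg.qr(tmp.T).Q.T              # (8, 36) with orthonormal rows
restored = q.reshape(update.shape)          # back to (8, 4, 3, 3)
assert restored.shape == update.shape
```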
@@ -333,7 +330,7 @@ def _update_psgd_cache(cached, Q_cache, q):

 def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
     if cached:
-        return utils.precond_grad_cached_(cache_expr, update, *
+        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
     return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)

@@ -350,9 +347,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
+    old = update
+    update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
-
+    out = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
+    return torch.as_strided(out, old.shape, old.stride())


 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
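`scale_by_psgd` now normalizes the incoming update to contiguous memory format before preconditioning and reinterprets the result with the original strides via `torch.as_strided`, which matters for channels-last tensors (hence the `test_channels_last.py` change at the end of this diff). A minimal demonstration of the two PyTorch calls involved, unrelated to the PSGD math itself:

```python
import torch

old = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
update = old.to(memory_format=torch.contiguous_format)      # dense NCHW copy, strides (48, 16, 4, 1)

out = update * 2.0                                           # stand-in for the preconditioned update
restored = torch.as_strided(out, old.shape, old.stride())    # reinterpret out's storage with the original strides

assert old.stride() == (48, 1, 12, 3)                        # channels-last strides
assert restored.stride() == old.stride()
```

`as_strided` does not move data; it only changes how the existing storage is indexed, so the returned tensor matches the layout of the original parameter.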
@@ -360,7 +360,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(
+    precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
     return precond

@@ -400,7 +400,7 @@ def apply_to_idx(fn, idx):


 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g) for g in grad]
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     skip_update = False
     for fn in fns:
         try:
@@ -417,7 +417,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 class ChainOpt(utils.StatefulOptimizer):
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
-
         self.fns = tuple(fns)

     def _step(self, group):
@@ -472,9 +471,10 @@ class BaseOpt(ChainOpt):
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
+    compile_step: bool = False

     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool =
+                 palm: bool = use_default, *fns):
         if default(update_clipping, self.update_clipping) is None:
             if fns and self.auto_fuse:
                 args, kwargs = None, None
@@ -489,6 +489,7 @@ class BaseOpt(ChainOpt):
         else:
             if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
                 raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+
         fns = tuple(fns)

         if default(palm, self.palm):
@@ -504,9 +505,9 @@ class BaseOpt(ChainOpt):
 class ScheduleFree(BaseOpt):
     def eval(self):
         for group in self.param_groups:
-            train_mode = group
+            group['train_mode'] = train_mode = not group.get('train_mode')
             beta1 = utils.get_beta1(group)
-            if beta1 > 0 and train_mode:
+            if beta1 > 0 and not train_mode:
                 for p in group['params']:
                     state = self.state_(p)
                     if 'z' in state:
@@ -515,13 +516,12 @@ class ScheduleFree(BaseOpt):
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - 1 / beta1)
                         utils.copy_stochastic_(p.data, p32)
-            group['train_mode'] = False

     def train(self):
         for group in self.param_groups:
-            train_mode = group
+            group['train_mode'] = train_mode = not group.get('train_mode')
             beta1 = utils.get_beta1(group)
-            if beta1 > 0 and
+            if beta1 > 0 and train_mode:
                 for p in group['params']:
                     state = self.state_(p)
                     if 'z' in state:
@@ -529,4 +529,3 @@ class ScheduleFree(BaseOpt):
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - beta1)
                         utils.copy_stochastic_(p.data, p32)
-            group['train_mode'] = True
heavyball/utils.py

@@ -163,7 +163,7 @@ def beta_debias(beta, step):
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
@@ -185,11 +185,11 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch.
+    out = torch._foreach_div(g32, denom)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, out)

@@ -201,6 +201,21 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad


+@decorator_knowngood
+def _compilable_exp_avg_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+def scale_by_exp_avg_(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_exp_avg_(state, grad, beta)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
     p_norm = torch._foreach_norm(parameters)
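Several hunks in this diff maintain exponential moving averages through `Tensor.lerp`. Since `a.lerp(b, w)` computes `a + w * (b - a)`, a lerp with weight `1 - beta` is the familiar `beta * state + (1 - beta) * grad` EMA (this is exactly how the `_lerp32` helper added later in this file uses it). A quick equivalence check in plain PyTorch:

```python
import torch

state = torch.randn(4)
grad = torch.randn(4)
beta = 0.9

classic = beta * state + (1 - beta) * grad
via_lerp = state.lerp(grad, 1 - beta)        # state + (1 - beta) * (grad - state)

assert torch.allclose(classic, via_lerp)
```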
@@ -210,7 +225,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
     torch._foreach_div_(p_norm, g_norm)
     torch._foreach_mul_(p_norm, clip_val)
     torch._foreach_minimum_(p_norm, 1)
-
+    torch._foreach_mul_(gradients, p_norm)


 def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
@@ -219,7 +234,8 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
         return gradients
     parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
-
+    _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+    return gradients


 def is_compiling():
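The two AGC hunks restore dropped lines: `_compilable_agc_` now actually rescales the gradients, and `adaptive_gradient_clipping_` runs the compiled kernel and returns the result. The effective rule is the usual adaptive factor `min(1, clip_val * ||p|| / ||g||)` per tensor. A single-tensor reference version of that formula (a sketch that ignores the `minimum` floor on the parameter norm, not heavyball's foreach implementation):

```python
import torch


def agc_reference(param: torch.Tensor, grad: torch.Tensor, clip_val: float, eps: float = 1e-8) -> torch.Tensor:
    p_norm = param.norm()
    g_norm = grad.norm().clamp(min=eps)
    factor = (clip_val * p_norm / g_norm).clamp(max=1.0)   # min(1, clip_val * ||p|| / ||g||)
    return grad * factor


p = torch.randn(10)
g = 100 * torch.randn(10)                                  # an unusually large gradient
clipped = agc_reference(p, g, clip_val=0.01)
assert clipped.norm() <= 0.01 * p.norm() + 1e-6
```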
@@ -289,7 +305,7 @@ def ortho(x):
 @decorator_knowngood
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta)
     torch._foreach_add_(s32, g32)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)
@@ -298,7 +314,7 @@ def _compilable_heavyball_momentum_(state, grad, beta):
 @decorator_knowngood
 def _compilable_nesterov_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta)
     torch._foreach_add_(s32, g32)
     [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
     copy_stochastic_list_(state, s32)
@@ -319,17 +335,27 @@ def nesterov_momentum(state, grad, beta):
     return grad


+# mode in ("newtonschulz", "qr", "svd")
+# scale_mode in ("none", "scale", "graft")
 @decorator_knowngood
-def inplace_orthogonal_(x, mode, out):
-    if mode == '
-        y =
+def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
+    if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
+        y = zeropower_via_newtonschulz5(x, 5)
+    elif mode == 'qr':
+        y = torch.linalg.qr(promote(x)).Q
     elif mode == 'svd':
-        u, s, v = torch.linalg.svd(x)
+        u, s, v = torch.linalg.svd(promote(x))
         y = u @ v.T
-    elif mode == 'newtonschulz':
-        y = zeropower_via_newtonschulz5(x, 5)
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    if scale_mode == "none":
+        pass
+    elif scale_mode == "scale":
+        y *= max(1, x.size(0) / x.size(1)) ** 0.5
+    elif scale_mode == "graft":
+        y *= x.norm() / y.norm().clamp_(min=1e-6)
+    else:
+        raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
     set_(out, y)

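`inplace_orthogonal_` gains an explicit `scale_mode` on top of the orthogonalization `mode`: `"none"` leaves the orthogonal factor as is, `"scale"` multiplies by a shape-dependent factor `max(1, rows/cols) ** 0.5`, and `"graft"` rescales to the input's norm. A toy comparison of the three factors, using QR as the orthogonalizer (a sketch; heavyball's default path is Newton-Schulz):

```python
import torch

x = torch.randn(16, 4) * 3.0
y = torch.linalg.qr(x).Q                              # orthonormal columns, shape (16, 4)

y_none = y                                            # Frobenius norm sqrt(4) = 2
y_scale = y * max(1, x.size(0) / x.size(1)) ** 0.5    # scaled by sqrt(16 / 4) = 2
y_graft = y * (x.norm() / y.norm().clamp(min=1e-6))   # keeps the original update magnitude

print(y_none.norm(), y_scale.norm(), y_graft.norm(), x.norm())
```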
@@ -363,7 +389,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-        inplace_orthogonal_(tmp[:, sort_idx], q)
+        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -422,8 +448,7 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-        x32.
-        copy_stochastic_(x_, x32)
+        copy_stochastic_(x_, x32.lerp(y32, a))


 def get_beta1(group):
@@ -484,7 +509,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-        x32.add_(y32, alpha=alpha)
+        x32.add_(y32, alpha=alpha)  # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
         copy_stochastic_(x_, x32)

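Both stochastic hunks keep the same structure: promote the low-precision tensor to fp32, do the arithmetic there, then write back with `copy_stochastic_`. Stochastic rounding itself is commonly implemented by adding random bits below the bf16 mantissa before truncating; the sketch below illustrates that general technique and is an assumption about the idea, not a copy of heavyball's `copy_stochastic_`:

```python
import torch


def stochastic_round_to_bf16(x32: torch.Tensor) -> torch.Tensor:
    # Add uniform noise to the 16 bits that bf16 discards, then truncate.
    bits = x32.view(torch.int32)
    noise = torch.randint(0, 1 << 16, x32.shape, dtype=torch.int32, device=x32.device)
    rounded = (bits + noise) & -65536            # zero out the lower 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)


x = torch.full((100_000,), 1.0 + 2 ** -10)       # value not representable in bf16
print(stochastic_round_to_bf16(x).float().mean())  # close to 1 + 2**-10 on average, unlike plain truncation
```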
@@ -506,7 +531,7 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
         g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-        GG[idx].lerp_(
+        GG[idx].lerp_(outer_product, 1 - beta)


 def promote(x):
@@ -571,7 +596,8 @@ def project(grad, Q, back: bool):
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
-
+        out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+        grad = out.to(grad.dtype)
     return grad

@@ -724,20 +750,26 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
         copy_stochastic_(t, s)


+def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
+    ea32 = list(map(promote, state))
+    grad = list(map(promote, grad))
+
+    ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+    copy_stochastic_list_(state, ea32)
+    return ea32
+
+
 @decorator_knowngood
 def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
                       step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

-    g32
+    g32 = list(map(promote, grad))

-
-    denom = exp_avg_sq_(
+    exp_avg32 = _lerp32(exp_avg, g32, beta1)
+    denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
     copy_stochastic_list_(grad, u32)

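`_compilable_adam_` now routes the first moment through the shared `_lerp32` helper and takes the second-moment denominator from `exp_avg_sq_`, with both betas already passed through `beta_debias`. For reference, the per-tensor arithmetic the foreach code computes is the standard Adam direction. A minimal single-tensor version under that reading — a sketch, assuming `beta_debias(beta, step)` folds the bias correction into the decay rate as `1 - (1 - beta) / (1 - beta ** step)`:

```python
import torch


def adam_direction(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps=1e-8):
    # step counts from 1; debiased decay rates as beta_debias is assumed to produce.
    b1 = 1 - (1 - beta1) / (1 - beta1 ** step)
    b2 = 1 - (1 - beta2) / (1 - beta2 ** step)

    exp_avg.lerp_(grad, 1 - b1)                              # first-moment EMA
    exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)   # second-moment EMA
    denom = exp_avg_sq.sqrt().clamp(min=eps)
    return exp_avg / denom                                   # update direction; lr and decay applied elsewhere
```

The fused variants below additionally thread the raw `grad` (for the cautioning mask) and `step`/`eps` through the signature, which is the visible API change in this release.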
@@ -749,28 +781,26 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b


 @decorator_knowngood
-def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                            beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
-                            caution: bool):
+def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                            grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
+                            eps: Tensor, caution: bool):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

-
+    u32, g32 = [list(map(promote, x)) for x in [update, grad]]

-
-    denom = exp_avg_sq_(
+    exp_avg32 = _lerp32(exp_avg, u32, beta1)
+    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-    _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)


-def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                beta2: float, step: int, lr: float, eps: float, decay: float,
+def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
+                caution: bool):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
-    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


 @decorator_knowngood
@@ -779,14 +809,13 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

-    gp32
+    gp32 = list(map(promote, grad))

-    denom = exp_avg_sq_(
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
-
+    gp32 = _lerp32(exp_avg, gp32, beta1)

-    copy_stochastic_list_(
-    copy_stochastic_list_(grad, exp_avg)
+    copy_stochastic_list_(grad, gp32)


 def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
@@ -797,52 +826,50 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],


 @decorator_knowngood
-def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
+def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                              grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
+                              caution: bool):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

-
+    u32, gp32 = [list(map(promote, x)) for x in [update, grad]]

-    denom = exp_avg_sq_(
-
-
-
-
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+    u32 = torch._foreach_div(u32, denom)
+    u32 = _lerp32(exp_avg, u32, beta1)
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)


-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                  grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)


 @decorator_knowngood
-def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-
-
+def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)

     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
     [denom.clamp_(min=eps) for denom in denom]
-    torch.
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32,
+    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+    copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    torch.
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32,
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
+    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


-
+
+def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
-    _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+    _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)


 @decorator_knowngood
@@ -853,21 +880,21 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
     [denom.clamp_(min=1e-8) for denom in denom]
-    torch.
+    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+    copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    torch.
+    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
     [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
     copy_stochastic_list_(grad, update)


 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad
-    beta1, beta2, step = scalar_guard(beta1, beta2, step,
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
     return grad

@@ -912,7 +939,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:

     for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
         if caution:
-
+            u32_ = _compilable_cautioning(promote(g_), u32_)
         add_fn(p32_, u32_, lr)

     copy_stochastic_list_(p, p32)
@@ -1228,7 +1255,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
     prob = prob(group[f'{name}_prob_step'])
     if group['stochastic_schedule']:
         return rng.random() < prob
-    cumulative_prob =
+    cumulative_prob = group.get(name, 0)
    group[name] = cumulative_prob + prob
     return int(group[name]) > int(cumulative_prob)

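The `psgd_should_update` fix initializes the cumulative probability from the group (`group.get(name, 0)`) instead of leaving it undefined. With the non-stochastic schedule, an update fires exactly when the integer part of the accumulated probability increases, so `prob=0.25` triggers one preconditioner update every fourth call. A self-contained re-implementation of that accumulator for illustration:

```python
def should_update(group: dict, prob: float, name: str = "cumulative_prob") -> bool:
    cumulative_prob = group.get(name, 0)
    group[name] = cumulative_prob + prob
    return int(group[name]) > int(cumulative_prob)


group = {}
fired = [should_update(group, 0.25) for _ in range(8)]
print(fired)  # [False, False, False, True, False, False, False, True]
```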
@@ -1289,15 +1316,16 @@ def mars_correction(g, old_g, beta1, gamma):


 @decorator_knowngood
-def
-    mask = (
-    update.
-    scale = mask.numel() / mask.sum().clamp(min=1)
+def _compilable_cautioning(g: Tensor, update: Tensor):
+    mask = g.signbit() ^ update.signbit()  # "Mask if they point in different directions"
+    update = update.masked_fill(mask, 0)
+    scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
     update.mul_(scale)
+    return update


 def caution(g, update):
-
+    return _compilable_cautioning(g, update)


 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
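The rebuilt `_compilable_cautioning` zeroes update entries whose sign disagrees with the gradient and rescales the survivors by `numel / kept` so the overall magnitude is preserved. A dense-tensor illustration that mirrors the lines above in plain PyTorch rather than calling heavyball:

```python
import torch

g = torch.tensor([1.0, -2.0, 3.0, -4.0])
update = torch.tensor([0.5, 0.5, -0.5, -0.5])

mask = g.signbit() ^ update.signbit()             # True where the signs differ
cautioned = update.masked_fill(mask, 0)
scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
cautioned = cautioned * scale

print(cautioned)                                  # tensor([1., 0., 0., -1.])
```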
test/test_channels_last.py

@@ -11,13 +11,14 @@ from heavyball.utils import clean, set_torch
 from torch import nn
 from torch._dynamo import config

+heavyball.utils.zeroth_power_mode = 'newtonschulz'
 heavyball.utils.compile_mode = 'default'
 config.cache_size_limit = 128


 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 1)])
-def test_foreach(opt, size, depth: int, iterations: int =
+def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 1):
     set_torch()
     opt = getattr(heavyball, opt)
