heavyball 1.6.3__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +515 -100
- heavyball/chainable.py +487 -156
- heavyball/optimizations/__init__.py +38 -0
- heavyball/optimizations/integrator.py +169 -0
- heavyball/optimizations/optimizations.py +329 -0
- heavyball/utils.py +780 -241
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/METADATA +3 -2
- heavyball-1.7.1.dist-info/RECORD +11 -0
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/WHEEL +1 -1
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.3.dist-info/RECORD +0 -8
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -1,8 +1,10 @@
 import functools
+import math
 import random
-from typing import …
+from typing import List, Literal, Optional, Union

 import torch
+from torch import Tensor

 from . import utils

@@ -42,7 +44,7 @@ class FunctionTransform:
         raise NotImplementedError

     def get_fn(self):
-        if …
+        if utils.hasattr_none(self.fn, "get_fn"):
            return self.fn.get_fn()
        return self.fn

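Several call sites in 1.7.1 now go through utils.hasattr_none instead of plain hasattr. Judging from its uses in this diff (attributes such as param.vector that may be attached but set to None), a plausible stand-in looks like the sketch below; the real helper lives in heavyball/utils.py and may differ.

def hasattr_none(obj, name):
    # Hypothetical re-implementation: the attribute only counts if it exists and is not None.
    return getattr(obj, name, None) is not None


class Box:
    value = None


print(hasattr_none(Box(), "value"))    # False - present but None
print(hasattr_none(Box(), "missing"))  # False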
@@ -55,7 +57,7 @@ def _zero_guard(state, key, ref, dtype):


 def _storage_dtype(group):
-    dtype = group.get(…
+    dtype = group.get("storage_dtype", "float32")
     return getattr(torch, dtype)


@@ -65,8 +67,10 @@ class ZeroGuard(FunctionTransform):
         self.names = names

     def __call__(self, state, group, update, grad, param, *args, **kwargs):
-        vars = […
-        …
+        vars = [
+            [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param]  #
+            for name in self.names
+        ]
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


@@ -78,8 +82,10 @@ class CopyGuard(FunctionTransform):

     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         val = [update, grad, param, *args][self.index]
-        vars = […
-        …
+        vars = [
+            [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+            for name in self.names
+        ]
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


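The rewritten ZeroGuard.__call__ (and the matching CopyGuard change above) builds one buffer list per guarded name, with one entry per parameter. A minimal standalone sketch of that pattern, using a plain dict in place of the optimizer state and hypothetical parameters:

import torch

params = [torch.randn(3), torch.randn(4, 2)]  # hypothetical parameter list
state = {id(p): {} for p in params}           # stand-in for the per-parameter state dict
names = ["exp_avg", "exp_avg_sq"]             # guarded buffer names


def zero_guard(st, key, ref):
    # Create a zero buffer shaped like `ref` the first time `key` is requested.
    if key not in st:
        st[key] = torch.zeros_like(ref)
    return st[key]


# One inner list per guarded name, each holding one buffer per parameter,
# mirroring the nested comprehension introduced in ZeroGuard.__call__.
vars = [[zero_guard(state[id(p)], name, p) for p in params] for name in names]
print([len(v) for v in vars])  # [2, 2]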
@@ -152,145 +158,243 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


-@zero_guard(…
+@zero_guard("exp_avg")
 @no_state
 def weight_decay_to_ema(group, update, grad, param, exp_avg):
-    utils.weight_decay_to_ema_(…
-    …
+    utils.weight_decay_to_ema_(
+        exp_avg,
+        update,
+        utils.beta_debias(group["ema_beta"], group["step"]),
+        group["weight_decay_to_ema"] * group["lr"],
+    )
     return update


-@zero_guard(…
+@zero_guard("exp_avg")
 @no_state
 def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
-    utils.l1_weight_decay_to_ema_(…
-    …
+    utils.l1_weight_decay_to_ema_(
+        exp_avg,
+        update,
+        utils.beta_debias(group["ema_beta"], group["step"]),
+        group["weight_decay_to_ema"] * group["lr"],
+    )
     return update


 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-    return utils.scale_by_exp_avg_sq_(…
-    …
+    return utils.scale_by_exp_avg_sq_(
+        exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+    )


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.adam_(…
-    …
+    return utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],  #
+        group["eps"],
+    )


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_adam_(…
-    …
+    utils.fused_adam_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["eps"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group[…
+    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group["step"])


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_laprop_(…
-    …
+    utils.fused_laprop_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @no_state
 def orthogonalize_grad_to_param(group, update, grad, param):
-    return utils.orthogonalize_grad_to_param(param, update, group[…
+    return utils.orthogonalize_grad_to_param(param, update, group["eps"])


 @copy_guard(2, "z")
 @no_state
 def update_by_schedule_free(group, update, grad, param, z):
-    group[…
-    …
-    …
+    group["weight_sum"] = utils.schedule_free_(
+        group["lr"],
+        group["weight_lr_power"],
+        group.get("weight_sum", 0),
+        utils.get_beta1(group),
+        param,
+        z,
+        update,
+        grad,
+        group["caution"],
+        group["r"],
+        group["step"],
+        group["weight_decay"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group[…
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group[…
+    if group["step"] == 1:
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
         raise SkipUpdate

-    if group[…
+    if group["step"] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
-        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group[…
-        utils.scale_by_exp_avg_sq_(…
-        …
+        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+        utils.scale_by_exp_avg_sq_(
+            exp_avg_sq,
+            update,
+            utils.beta_debias(utils.get_beta2(group), group["step"]),
+            group["eps"],
+        )
         raise SkipUpdate

-    utils.fused_adopt_(…
-    …
+    utils.fused_adopt_(
+        param,
+        update,
+        grad,
+        exp_avg_sq,
+        exp_avg,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 2,
+        group["lr"],
+        group["eps"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group[…
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group[…
+    if group["step"] == 1:
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
         raise SkipUpdate

-    if group[…
+    if group["step"] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
-        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group[…
-        utils.scale_by_exp_avg_sq_(…
-        …
+        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+        utils.scale_by_exp_avg_sq_(
+            exp_avg_sq,
+            update,
+            utils.beta_debias(utils.get_beta2(group), group["step"]),
+            group["eps"],
+        )
         raise SkipUpdate

-    return utils.adopt(…
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    state["…
+    return utils.adopt(
+        update,
+        exp_avg_sq,
+        exp_avg,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 2,
+    )
+
+
+def _init_soap(state, group, update, grad, param, inner: str = ""):
+    utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"])
+
+
+def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    Q, state["exprs"] = utils.init_Q_exprs(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["max_size_triangular"],
+        group["min_ndim_triangular"],
+        group["memory_save_mode"],
+        getattr(param, "hessian_vector", None),
+        getattr(param, "vector", None),
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q

     if not cached:
         return

-    state[…
+    state["Q_cache"] = [torch.empty_like(q) for q in Q]

-    expr = [f…
-    expr = …
-    grad_expr = …
-    out_expr = …
-    expr = f…
+    expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
+    expr = ",".join(expr)
+    grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
+    out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
+    expr = f"{expr},{grad_expr}->{out_expr}"

-    state[…
+    state["cache_expr"] = expr


-def …
-…
-…
-…
+def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    state["U"], state["V"], state["d"] = utils.init_lra(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["rank"],
+        getattr(param, "hessian_vector", None),
+        getattr(param, "vector", None),
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    group["preconditioning_step"] = 0
+
+
+def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
+    step = group["step"]
+    if "precondition_frequency" in group:
+        return step > 0 and step % group["precondition_frequency"] == 0
     if isinstance(step, torch.Tensor):
         utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
         rng = random.Random(0x172381)
     else:
         rng = random.Random(0x172381 ^ step)
-    if …
-        return utils.precond_schedule(step, group[…
+    if "precond_scheduler" in group:
+        return utils.precond_schedule(step, group["precond_scheduler"], rng)
     if prob is not None:
         return utils.psgd_should_update(group, prob, rng, name=name)
     raise ValueError("No preconditioner update schedule specified.")
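The new precond_schedule helper decides whether to refresh a preconditioner this step: a fixed precondition_frequency wins, then a precond_scheduler, then a probabilistic rule. A simplified sketch of the first and last branches with the utils.* helpers replaced by plain Python (the real probabilistic path goes through utils.psgd_should_update and tracks a cumulative probability):

import random


def precond_schedule_sketch(group, prob=None):
    step = group["step"]
    # A fixed update frequency takes precedence, as in the new chainable.precond_schedule.
    if "precondition_frequency" in group:
        return step > 0 and step % group["precondition_frequency"] == 0
    # Otherwise fall back to a (simplified) probabilistic schedule seeded per step.
    rng = random.Random(0x172381 ^ step)
    if prob is not None:
        p = prob(step) if callable(prob) else prob
        return rng.random() < p
    raise ValueError("No preconditioner update schedule specified.")


print(precond_schedule_sketch({"step": 10, "precondition_frequency": 5}))  # True
print(precond_schedule_sketch({"step": 7}, prob=0.1))                      # stochastic decision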
@@ -313,17 +417,17 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


-@zero_guard(…
+@zero_guard("momentum")
 @no_state
 def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
     return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))


 def _store_std(state, group, update, grad, param):
-    state[…
+    state["init_std"] = torch.std(grad, dim=0)


-@general_guard("init_std", init_fn=_store_std)
+@general_guard("init_std", init_fn=_store_std, skip_first=False)
 @no_state
 def mup_approx(group, updates, grads, params, init_std):
     _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
@@ -332,31 +436,79 @@ def mup_approx(group, updates, grads, params, init_std):
     return updates


+def _init_delta(state, group, update, grad, param, log_space: bool):
+    val = group["initial_d"]
+    state["delta"] = torch.full((), math.log(val) if log_space else val, dtype=param.dtype, device=param.device)
+
+
+def _init_full_delta(state, group, update, grad, param, log_space: bool):
+    val = group["initial_d"]
+    state["delta"] = torch.full_like(param, math.log(val) if log_space else val)
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
+@no_state
+def scale_by_d_adaptation(group, update, grad, param, state, delta):
+    utils.d_adaptation(grad, update, state, delta)
+    return update
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
+@no_state
+def scale_by_lr_adaptation(group, update, grad, param, state, delta):
+    utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
+    return update
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
+@no_state
+def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
+    utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
+    return update
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
     return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))


-_optim_fns = {…
+_optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}


 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
-def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str =…
+def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
     update = utils.promote(update)  # Promote to highest precision if needed

     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(…
-    …
+    precond = fn(
+        exp_avg,
+        exp_avg_sq,
+        grad_projected,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+        group["eps"],
+    )
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

     for u, q, gg, ea in zip(update, Q, GG, exp_avg):
-        utils.update_preconditioner(…
-        …
-        …
+        utils.update_preconditioner(
+            u,
+            q,
+            gg,
+            ea,
+            group["max_precond_dim"],
+            group["precondition_1d"],
+            utils.beta_debias(group["shampoo_beta"], group["step"]),
+            group["is_preconditioning"],
+        )
     return precond


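The added _init_delta/_init_full_delta helpers store the initial step-size estimate for the new d-adaptation and LR-adaptation transforms either as a 0-dim scalar or as a per-element tensor, optionally in log space. A small sketch of that initialization, with a hypothetical initial_d value standing in for group["initial_d"]:

import math

import torch

param = torch.randn(4, 3)
initial_d = 1e-6  # hypothetical group["initial_d"]

# Scalar delta in log space, as in _init_delta(log_space=True).
delta_scalar = torch.full((), math.log(initial_d), dtype=param.dtype, device=param.device)

# Per-element delta, as in _init_full_delta(log_space=True).
delta_full = torch.full_like(param, math.log(initial_d))

print(delta_scalar.shape, delta_full.shape)  # torch.Size([]) torch.Size([4, 3])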
@@ -364,17 +516,28 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if prob is None:
         prob = utils.precond_update_prob_schedule()

-    if not group[…
+    if not group["is_preconditioning"]:
         return Q_mat

-    utils.…
-    …
-    if hasattr(param, 'vector'):
+    if utils.hasattr_none(param, "vector"):
+        vector, hessian_vector = param.vector, param.hessian_vector
         del param.vector
         del param.hessian_vector
+    else:
+        vector, hessian_vector = utils.dampen_grad(grad)
+
+    utils.psgd_update_precond(
+        Q_mat,
+        exprs,
+        hessian_vector,
+        group["precond_lr"],
+        Q,
+        group["store_triu_as_line"],
+        vector,
+    )

     if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
-        if group[…
+        if group["store_triu_as_line"]:
             utils.psgd_balance_Q([q_ for _, q_ in Q])
         else:
             utils.psgd_balance_Q(Q)
@@ -382,8 +545,8 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if isinstance(prob, float):
         float_prob = prob
     else:
-        float_prob = prob(group.get(f…
-    group[…
+        float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+    group["is_cached"] = should_use_cache = cached and float_prob < 0.5

     if should_use_cache:  # caching adds extra ops and is not worth the overhead when we precondition at every step
         return _update_psgd_cache(cached, Q_cache, Q_mat)
@@ -403,51 +566,172 @@ def _update_psgd_cache(cached, Q_cache, q):


 def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
-    if group.get(…
-        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group[…
-    …
-    …
+    if group.get("is_cached", False):
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
+    else:
+        out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
+    group["caution"] = False  # we already cautioned here - shouldn't do it again
     return out


 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
-    if group.get(…
-        utils.fused_precond_grad_cached_(…
-        …
+    if group.get("is_cached", False):
+        utils.fused_precond_grad_cached_(
+            cache_expr,
+            update,
+            param,
+            group["lr"],
+            grad,
+            group["weight_decay"],
+            group["caution"],
+            *Q_cache,
+        )
     else:
-        utils.fused_psgd_precond_grad(…
-        …
+        utils.fused_psgd_precond_grad(
+            exprs[-1],
+            update,
+            param,
+            group["lr"],
+            grad,
+            group["weight_decay"],
+            group["caution"],
+            *Q_mat,
+        )
+
+
+def _update_lra(
+    group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
+):
+    if not group["is_preconditioning"]:
+        return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
+
+    if utils.hasattr_none(params[0], "hessian_vector"):
+        vector = utils.flatten([p.vector for p in params])
+        hessian_vector = utils.flatten([p.hessian_vector for p in params])
+        for p in params:
+            del p.vector
+            del p.hessian_vector
+    else:
+        vector, hessian_vector = utils.dampen_multiple(grads)
+    return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def scale_by_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+    return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def update_by_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+    utils.apply_lra_update(param, update, u, v, d)
+    raise SkipUpdate
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+    return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+    utils.apply_lra_update(param, update, u, v, d)
+    raise SkipUpdate


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=…
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def scale_by_psgd(…
-…
+def scale_by_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
     update = update.to(memory_format=torch.contiguous_format)
-    Q_mat = utils.line_to_triu(Q) if group[…
-    Q_mat = _update_psgd_precond(…
-    …
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+    Q_mat = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=…
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def scale_by_delayed_psgd(…
-…
-…
+def scale_by_delayed_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
     precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
-    _ = _update_psgd_precond(…
-    …
+    _ = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     return precond


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=…
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def update_by_psgd(…
-…
-…
-…
-…
+def update_by_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+    Q_mat = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate

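_update_lra prefers an explicit Hessian-vector product that a higher-level routine has attached to the parameters and only falls back to a dampened-gradient probe otherwise. The sketch below shows that branch in isolation; the fallback form (grad plus a small multiple of a random probe) is an assumption about utils.dampen_multiple based on standard PSGD gradient whitening, not a copy of it:

import torch


def pick_hvp(params, grads, dampening=2**-13):
    # Prefer an attached (vector, hessian_vector) pair, as _update_lra does.
    if getattr(params[0], "hessian_vector", None) is not None:
        vector = [p.vector for p in params]
        hessian_vector = [p.hessian_vector for p in params]
    else:
        # Fallback: pair a random probe with a slightly perturbed gradient.
        vector = [torch.randn_like(g) for g in grads]
        hessian_vector = [g + dampening * v for g, v in zip(grads, vector)]
    return vector, hessian_vector


grads = [torch.randn(5), torch.randn(3, 2)]
params = [torch.randn(5), torch.randn(3, 2)]
v, hv = pick_hvp(params, grads)
print(len(v), len(hv))  # 2 2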
@@ -457,20 +741,39 @@ def sign(group, update, grad, param, graft: bool = True):
     return utils.sign_(update, graft)


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=…
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def update_by_delayed_psgd(…
-…
-…
+def update_by_delayed_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-    _ = _update_psgd_precond(…
-    …
+    _ = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     raise SkipUpdate


 def palm_beta2(state, group, update, grad, param):
-    beta2 = 1 - group[…
-    group[…
+    beta2 = 1 - group["step"] ** -group["beta2_scale"]
+    group["betas"] = (utils.get_beta1(group), beta2)
     return update


@@ -499,7 +802,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
     update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
-        utils.update_param_(param, update, group[…
+        utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)


 def create_branch(branches: List[List[callable]], merge_fn: callable):
@@ -524,14 +827,16 @@ class ChainOpt(utils.StatefulOptimizer):
         self.fns = tuple(fns)

     def _step(self, group):
-        if …
-            group[…
-        if …
-            utils.warn_once(…
-            …
-            …
+        if "base_lr" not in group:
+            group["base_lr"] = group["lr"]
+        if "prev_lr" in group and group["prev_lr"] != group["lr"]:
+            utils.warn_once(
+                f"Learning rate changed between steps. This is an experimental feature and "
+                f"only supported with foreach=True (currently foreach={group['foreach']})."
+            )
+            group["base_lr"] = group["lr"]

-        caution = group[…
+        caution = group["caution"]

         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))

@@ -541,26 +846,26 @@ class ChainOpt(utils.StatefulOptimizer):

         for param in p:
             state = self.state_(param)
-            if …
-                step = state[…
+            if "step" in state:
+                step = state["step"]
             elif self.compile_step:
                 step = utils.scalar_guard(0, param)
             else:
                 step = 0
             break

-        group[…
-        group[…
+        group["step"] = state["step"] = step = step + 1
+        group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)

-        if not group[…
+        if not group["foreach"] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
         else:
             chain(self.state_, group, g, p, *self.fns)

-        group[…
-        group[…
-        group[…
+        group["caution"] = caution
+        group["lr"] = group["prev_lr"]
+        group["step"] = None


 use_default = object()
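ChainOpt._step now derives the effective learning rate from a stored base_lr and a linear warmup: lr = base_lr * step / max(step, warmup_steps + 1). A quick check of that formula:

def warmup_lr(base_lr, step, warmup_steps):
    # Mirrors group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
    return base_lr * step / max(step, warmup_steps + 1)


print(warmup_lr(1e-3, 1, 100))    # ~9.9e-06, early in warmup
print(warmup_lr(1e-3, 50, 100))   # ~5.0e-04, roughly halfway
print(warmup_lr(1e-3, 500, 100))  # 1e-03, full base_lr once step exceeds warmup_steps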
@@ -571,7 +876,13 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
     name = default(name, default_val)
     if callable(name):
         return name
-    elif name not in (…
+    elif name not in (
+        "l2_clip_",
+        "rmsnorm_clip_",
+        "trust_region_clip_",
+        "a_law_compress",
+        "mu_law_compress",
+    ):
         raise ValueError(f"Clipping function {name} not found")
     return getattr(utils, name)

@@ -581,16 +892,24 @@ def default(a, b):


 # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
-_scale_to_update_map = {…
-…
-…
-…
-…
-…
-…
-…
-…
+_scale_to_update_map = {
+    scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,  #
+    scale_by_psgd.get_fn(): update_by_psgd,  #
+    scale_by_psgd_lra.get_fn(): update_by_psgd_lra,  #
+    scale_by_delayed_psgd_lra.get_fn(): update_by_delayed_psgd_lra,  #
+    scale_by_adam.get_fn(): update_by_adam,  #
+    scale_by_laprop.get_fn(): update_by_laprop,  #
+    scale_by_adopt.get_fn(): update_by_adopt,  #
+}
+_scale_to_update_map_inv = {
+    update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd,  #
+    update_by_psgd.get_fn(): scale_by_psgd,  #
+    update_by_psgd_lra.get_fn(): scale_by_psgd_lra,  #
+    update_by_delayed_psgd_lra.get_fn(): scale_by_delayed_psgd_lra,  #
+    update_by_adam.get_fn(): scale_by_adam,  #
+    update_by_laprop.get_fn(): scale_by_laprop,  #
+    update_by_adopt.get_fn(): scale_by_adopt,  #
+}


 class BaseOpt(ChainOpt):
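_scale_to_update_map and its inverse let BaseOpt swap the terminal transform between a scale_by_* form and the fused update_by_* form, depending on auto_fuse and whether update_clipping is set (see the raise in the next hunk). A toy version of that lookup with placeholder functions; the real maps key on .get_fn() results and the surrounding logic in BaseOpt.__init__ is more involved:

# Placeholder transforms standing in for chainable.scale_by_adam / update_by_adam.
def scale_by_adam(*args): ...
def update_by_adam(*args): ...


_scale_to_update_map = {scale_by_adam: update_by_adam}
_scale_to_update_map_inv = {update_by_adam: scale_by_adam}


def resolve_last_fn(fn, update_clipping, auto_fuse):
    # Sketch of the fusion rule: without update clipping, fuse scale_by_* into the
    # update_by_* variant; with update clipping, fall back to the unfused scale_by_* form.
    if not update_clipping and auto_fuse and fn in _scale_to_update_map:
        return _scale_to_update_map[fn]
    if update_clipping and fn in _scale_to_update_map_inv:
        if not auto_fuse:
            raise ValueError("update_clipping is not compatible with update_by_* functions")
        return _scale_to_update_map_inv[fn]
    return fn


print(resolve_last_fn(scale_by_adam, update_clipping=False, auto_fuse=True) is update_by_adam)  # True
print(resolve_last_fn(update_by_adam, update_clipping=True, auto_fuse=True) is scale_by_adam)   # True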
@@ -622,8 +941,18 @@ class BaseOpt(ChainOpt):
     palm: bool = False
     auto_fuse: bool = True

-    def __init__(…
-    …
+    def __init__(
+        self,
+        params,
+        defaults,
+        foreach: bool,
+        gradient_clipping: str_or_fn,
+        update_clipping: str_or_fn,
+        palm: bool = use_default,
+        *fns,
+        compile_step: bool = use_default,
+        promote: bool = use_default,
+    ):
         if not fns:
             raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")

@@ -643,8 +972,10 @@ class BaseOpt(ChainOpt):
            fns = tuple(fns)[:-1] + (fn,)
        elif fn in _scale_to_update_map_inv:
            if not self.auto_fuse:
-                raise ValueError(…
-                …
+                raise ValueError(
+                    "update_clipping is currently not compatible with update_by_* functions. "
+                    "Manually select scale_by_* functions or set auto_fuse=True."
+                )
            fn = _scale_to_update_map_inv[fn]
            if args is not None:
                fn = functools.partial(fn, *args, **kwargs)
@@ -665,27 +996,27 @@ class ScheduleFree(BaseOpt):
 class ScheduleFree(BaseOpt):
     def eval(self):
         for group in self.param_groups:
-            group[…
+            group["train_mode"] = train_mode = not group.get("train_mode")
             beta1 = utils.get_beta1(group)
             if beta1 > 0 and not train_mode:
-                for p in group[…
+                for p in group["params"]:
                     state = self.state_(p)
-                    if …
+                    if "z" in state:
                         # Set p.data to x
-                        z = utils.promote(state[…
+                        z = utils.promote(state["z"])
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - 1 / beta1)
                         utils.copy_stochastic_(p.data, p32)

     def train(self):
         for group in self.param_groups:
-            group[…
+            group["train_mode"] = train_mode = not group.get("train_mode")
             beta1 = utils.get_beta1(group)
             if beta1 > 0 and train_mode:
-                for p in group[…
+                for p in group["params"]:
                     state = self.state_(p)
-                    if …
-                        z = utils.promote(state[…
+                    if "z" in state:
+                        z = utils.promote(state["z"])
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - beta1)
                         utils.copy_stochastic_(p.data, p32)