heavyball 1.6.2__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +496 -100
- heavyball/chainable.py +444 -155
- heavyball/utils.py +326 -143
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/METADATA +11 -4
- heavyball-1.7.0.dist-info/RECORD +8 -0
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/WHEEL +1 -1
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.2.dist-info/RECORD +0 -8
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -1,8 +1,9 @@
 import functools
 import random
-from typing import
+from typing import List, Literal, Optional, Union

 import torch
+from torch import Tensor

 from . import utils

@@ -42,7 +43,7 @@ class FunctionTransform:
         raise NotImplementedError

     def get_fn(self):
-        if hasattr(self.fn,
+        if hasattr(self.fn, "get_fn"):
             return self.fn.get_fn()
         return self.fn

@@ -55,7 +56,7 @@ def _zero_guard(state, key, ref, dtype):


 def _storage_dtype(group):
-    dtype = group.get(
+    dtype = group.get("storage_dtype", "float32")
     return getattr(torch, dtype)

@@ -65,8 +66,10 @@ class ZeroGuard(FunctionTransform):
         self.names = names

     def __call__(self, state, group, update, grad, param, *args, **kwargs):
-        vars = [
-
+        vars = [
+            [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
+            for name in self.names
+        ]
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)

@@ -78,8 +81,10 @@ class CopyGuard(FunctionTransform):

     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         val = [update, grad, param, *args][self.index]
-        vars = [
-
+        vars = [
+            [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
+            for name in self.names
+        ]
         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)

@@ -152,145 +157,243 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


-@zero_guard(
+@zero_guard("exp_avg")
 @no_state
 def weight_decay_to_ema(group, update, grad, param, exp_avg):
-    utils.weight_decay_to_ema_(
-
+    utils.weight_decay_to_ema_(
+        exp_avg,
+        update,
+        utils.beta_debias(group["ema_beta"], group["step"]),
+        group["weight_decay_to_ema"] * group["lr"],
+    )
     return update


-@zero_guard(
+@zero_guard("exp_avg")
 @no_state
 def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
-    utils.l1_weight_decay_to_ema_(
-
+    utils.l1_weight_decay_to_ema_(
+        exp_avg,
+        update,
+        utils.beta_debias(group["ema_beta"], group["step"]),
+        group["weight_decay_to_ema"] * group["lr"],
+    )
     return update


 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-    return utils.scale_by_exp_avg_sq_(
-
+    return utils.scale_by_exp_avg_sq_(
+        exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+    )


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.adam_(
-
+    return utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"], #
+        group["eps"],
+    )


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_adam_(
-
+    utils.fused_adam_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["eps"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group[
+    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group["step"])


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_laprop_(
-
+    utils.fused_laprop_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @no_state
 def orthogonalize_grad_to_param(group, update, grad, param):
-    return utils.orthogonalize_grad_to_param(param, update, group[
+    return utils.orthogonalize_grad_to_param(param, update, group["eps"])


 @copy_guard(2, "z")
 @no_state
 def update_by_schedule_free(group, update, grad, param, z):
-    group[
-
-
+    group["weight_sum"] = utils.schedule_free_(
+        group["lr"],
+        group["weight_lr_power"],
+        group.get("weight_sum", 0),
+        utils.get_beta1(group),
+        param,
+        z,
+        update,
+        grad,
+        group["caution"],
+        group["r"],
+        group["step"],
+        group["weight_decay"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group[
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group[
+    if group["step"] == 1:
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
         raise SkipUpdate

-    if group[
+    if group["step"] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
-        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group[
-        utils.scale_by_exp_avg_sq_(
-
+        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+        utils.scale_by_exp_avg_sq_(
+            exp_avg_sq,
+            update,
+            utils.beta_debias(utils.get_beta2(group), group["step"]),
+            group["eps"],
+        )
         raise SkipUpdate

-    utils.fused_adopt_(
-
+    utils.fused_adopt_(
+        param,
+        update,
+        grad,
+        exp_avg_sq,
+        exp_avg,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 2,
+        group["lr"],
+        group["eps"],
+        group["weight_decay"],
+        group["caution"],
+    )
     raise SkipUpdate


 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group[
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group[
+    if group["step"] == 1:
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
         raise SkipUpdate

-    if group[
+    if group["step"] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
-        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group[
-        utils.scale_by_exp_avg_sq_(
-
+        [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+        utils.scale_by_exp_avg_sq_(
+            exp_avg_sq,
+            update,
+            utils.beta_debias(utils.get_beta2(group), group["step"]),
+            group["eps"],
+        )
         raise SkipUpdate

-    return utils.adopt(
-
-
-
-
-
-
-
-
-
-
-    state["
+    return utils.adopt(
+        update,
+        exp_avg_sq,
+        exp_avg,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 2,
+    )
+
+
+def _init_soap(state, group, update, grad, param, inner: str = ""):
+    utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"])
+
+
+def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    Q, state["exprs"] = utils.init_Q_exprs(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["max_size_triangular"],
+        group["min_ndim_triangular"],
+        group["memory_save_mode"],
+        getattr(param, "hessian_vector", None),
+        getattr(param, "vector", None),
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q

     if not cached:
         return

-    state[
+    state["Q_cache"] = [torch.empty_like(q) for q in Q]
+
+    expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
+    expr = ",".join(expr)
+    grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
+    out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
+    expr = f"{expr},{grad_expr}->{out_expr}"

-
-    expr = ','.join(expr)
-    grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
-    out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
-    expr = f'{expr},{grad_expr}->{out_expr}'
+    state["cache_expr"] = expr

-    state['cache_expr'] = expr

+def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    state["U"], state["V"], state["d"] = utils.init_lra(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["rank"],
+        getattr(param, "hessian_vector", None),
+        getattr(param, "vector", None),
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    group["preconditioning_step"] = 0

-
-
-
-
+
+def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
+    step = group["step"]
+    if "precondition_frequency" in group:
+        return step > 0 and step % group["precondition_frequency"] == 0
     if isinstance(step, torch.Tensor):
         utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
         rng = random.Random(0x172381)
     else:
         rng = random.Random(0x172381 ^ step)
-    if
-        return utils.precond_schedule(step, group[
+    if "precond_scheduler" in group:
+        return utils.precond_schedule(step, group["precond_scheduler"], rng)
     if prob is not None:
         return utils.psgd_should_update(group, prob, rng, name=name)
     raise ValueError("No preconditioner update schedule specified.")
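Note on the hunk above: `precond_schedule` now short-circuits on a fixed `precondition_frequency` group key before falling back to the `precond_scheduler` or probability-based paths. A minimal sketch of that branch in isolation; `should_precondition` is a hypothetical helper name, not part of the package:

def should_precondition(group: dict) -> bool:
    # Mirrors the new branch: precondition every N steps once step > 0.
    step = group["step"]
    if "precondition_frequency" in group:
        return step > 0 and step % group["precondition_frequency"] == 0
    # The "precond_scheduler" / probability-based fallbacks are omitted in this sketch.
    return False

# With precondition_frequency=4, steps 4 and 8 trigger a preconditioner update:
assert [should_precondition({"step": s, "precondition_frequency": 4}) for s in range(1, 9)] == [
    False, False, False, True, False, False, False, True,
]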
@@ -313,14 +416,14 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


-@zero_guard(
+@zero_guard("momentum")
 @no_state
 def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grokfast
     return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))


 def _store_std(state, group, update, grad, param):
-    state[
+    state["init_std"] = torch.std(grad, dim=0)


 @general_guard("init_std", init_fn=_store_std)
@@ -338,25 +441,39 @@ def heavyball_momentum(group, updates, grads, params, momentum):
     return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))


-_optim_fns = {
+_optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}


 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
-def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str =
+def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
     update = utils.promote(update) # Promote to highest precision if needed

     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(
-
+    precond = fn(
+        exp_avg,
+        exp_avg_sq,
+        grad_projected,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+        group["eps"],
+    )
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

     for u, q, gg, ea in zip(update, Q, GG, exp_avg):
-        utils.update_preconditioner(
-
-
+        utils.update_preconditioner(
+            u,
+            q,
+            gg,
+            ea,
+            group["max_precond_dim"],
+            group["precondition_1d"],
+            utils.beta_debias(group["shampoo_beta"], group["step"]),
+            group["is_preconditioning"],
+        )
     return precond

@@ -364,17 +481,24 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if prob is None:
         prob = utils.precond_update_prob_schedule()

-    if not group[
+    if not group["is_preconditioning"]:
         return Q_mat

-    utils.psgd_update_precond(
-
-
+    utils.psgd_update_precond(
+        Q_mat,
+        exprs,
+        getattr(param, "hessian_vector", grad),
+        group["precond_lr"],
+        Q,
+        group["store_triu_as_line"],
+        getattr(param, "vector", None),
+    )
+    if hasattr(param, "vector"):
         del param.vector
         del param.hessian_vector

     if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
-        if group[
+        if group["store_triu_as_line"]:
             utils.psgd_balance_Q([q_ for _, q_ in Q])
         else:
             utils.psgd_balance_Q(Q)
@@ -382,8 +506,8 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if isinstance(prob, float):
         float_prob = prob
     else:
-        float_prob = prob(group.get(f
-    group[
+        float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+    group["is_cached"] = should_use_cache = cached and float_prob < 0.5

     if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
         return _update_psgd_cache(cached, Q_cache, Q_mat)
@@ -403,51 +527,169 @@ def _update_psgd_cache(cached, Q_cache, q):


 def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
-    if group.get(
-        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group[
-
-
+    if group.get("is_cached", False):
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
+    else:
+        out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
+    group["caution"] = False # we already cautioned here - shouldn't do it again
     return out


 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
-    if group.get(
-        utils.fused_precond_grad_cached_(
-
+    if group.get("is_cached", False):
+        utils.fused_precond_grad_cached_(
+            cache_expr,
+            update,
+            param,
+            group["lr"],
+            grad,
+            group["weight_decay"],
+            group["caution"],
+            *Q_cache,
+        )
     else:
-        utils.fused_psgd_precond_grad(
-
+        utils.fused_psgd_precond_grad(
+            exprs[-1],
+            update,
+            param,
+            group["lr"],
+            grad,
+            group["weight_decay"],
+            group["caution"],
+            *Q_mat,
+        )
+
+
+def _update_lra(
+    group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
+):
+    if not group["is_preconditioning"]:
+        return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
+
+    if hasattr(params[0], "hessian_vector") and params[0].hessian_vector is not None:
+        vector = utils.flatten([p.vector for p in params])
+        hessian_vector = utils.flatten([p.hessian_vector for p in params])
+    else:
+        vector, hessian_vector = utils.dampen_multiple(grads)
+    return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def scale_by_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+    return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def update_by_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+    utils.apply_lra_update(param, update, u, v, d)
+    raise SkipUpdate
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+    return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+@general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+@no_state
+def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+    u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+    utils.apply_lra_update(param, update, u, v, d)
+    raise SkipUpdate


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def scale_by_psgd(
-
+def scale_by_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
     update = update.to(memory_format=torch.contiguous_format)
-    Q_mat = utils.line_to_triu(Q) if group[
-    Q_mat = _update_psgd_precond(
-
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+    Q_mat = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def scale_by_delayed_psgd(
-
-
+def scale_by_delayed_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
     precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
-    _ = _update_psgd_precond(
-
+    _ = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     return precond


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def update_by_psgd(
-
-
-
-
+def update_by_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+    Q_mat = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate

@@ -457,20 +699,39 @@ def sign(group, update, grad, param, graft: bool = True):
     return utils.sign_(update, graft)


-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
-def update_by_delayed_psgd(
-
-
+def update_by_delayed_psgd(
+    group,
+    update,
+    grad,
+    param,
+    Q,
+    exprs,
+    Q_cache,
+    cache_expr: str,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-    _ = _update_psgd_precond(
-
+    _ = _update_psgd_precond(
+        cached,
+        Q_cache,
+        group,
+        param,
+        update if group["momentum_into_precond_update"] else grad,
+        Q_mat,
+        Q,
+        exprs,
+        prob,
+    )
     raise SkipUpdate


 def palm_beta2(state, group, update, grad, param):
-    beta2 = 1 - group[
-    group[
+    beta2 = 1 - group["step"] ** -group["beta2_scale"]
+    group["betas"] = (utils.get_beta1(group), beta2)
     return update

@@ -499,7 +760,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
     update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
-        utils.update_param_(param, update, group[
+        utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)


 def create_branch(branches: List[List[callable]], merge_fn: callable):
@@ -524,14 +785,16 @@ class ChainOpt(utils.StatefulOptimizer):
         self.fns = tuple(fns)

     def _step(self, group):
-        if
-            group[
-        if
-            utils.warn_once(
-
-
+        if "base_lr" not in group:
+            group["base_lr"] = group["lr"]
+        if "prev_lr" in group and group["prev_lr"] != group["lr"]:
+            utils.warn_once(
+                f"Learning rate changed between steps. This is an experimental feature and "
+                f"only supported with foreach=True (currently foreach={group['foreach']})."
+            )
+            group["base_lr"] = group["lr"]

-        caution = group[
+        caution = group["caution"]

         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))

@@ -541,26 +804,26 @@ class ChainOpt(utils.StatefulOptimizer):

         for param in p:
             state = self.state_(param)
-            if
-                step = state[
+            if "step" in state:
+                step = state["step"]
             elif self.compile_step:
                 step = utils.scalar_guard(0, param)
             else:
                 step = 0
             break

-        group[
-        group[
+        group["step"] = state["step"] = step = step + 1
+        group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)

-        if not group[
+        if not group["foreach"] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
         else:
             chain(self.state_, group, g, p, *self.fns)

-        group[
-        group[
-        group[
+        group["caution"] = caution
+        group["lr"] = group["prev_lr"]
+        group["step"] = None


 use_default = object()
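The warmup applied in `ChainOpt._step` above scales `base_lr` by `step / max(step, warmup_steps + 1)`, a linear ramp that reaches `base_lr` at step `warmup_steps + 1` and stays flat afterwards. A minimal standalone sketch; `warmup_lr` is a hypothetical name, not package API:

def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Same expression as group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
    return base_lr * step / max(step, warmup_steps + 1)

# base_lr=1e-3, warmup_steps=3: the LR ramps over the first four steps, then stays flat.
print([round(warmup_lr(1e-3, s, 3), 6) for s in range(1, 7)])
# [0.00025, 0.0005, 0.00075, 0.001, 0.001, 0.001]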
@@ -571,7 +834,13 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
     name = default(name, default_val)
     if callable(name):
         return name
-    elif name not in (
+    elif name not in (
+        "l2_clip_",
+        "rmsnorm_clip_",
+        "trust_region_clip_",
+        "a_law_compress",
+        "mu_law_compress",
+    ):
         raise ValueError(f"Clipping function {name} not found")
     return getattr(utils, name)

@@ -581,16 +850,24 @@ def default(a, b):


 # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
-_scale_to_update_map = {
-
-
-
-
-
-
-
-
-
+_scale_to_update_map = {
+    scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+    scale_by_psgd.get_fn(): update_by_psgd, #
+    scale_by_psgd_lra.get_fn(): update_by_psgd_lra, #
+    scale_by_delayed_psgd_lra.get_fn(): update_by_delayed_psgd_lra, #
+    scale_by_adam.get_fn(): update_by_adam, #
+    scale_by_laprop.get_fn(): update_by_laprop, #
+    scale_by_adopt.get_fn(): update_by_adopt, #
+}
+_scale_to_update_map_inv = {
+    update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
+    update_by_psgd.get_fn(): scale_by_psgd, #
+    update_by_psgd_lra.get_fn(): scale_by_psgd_lra, #
+    update_by_delayed_psgd_lra.get_fn(): scale_by_delayed_psgd_lra, #
+    update_by_adam.get_fn(): scale_by_adam, #
+    update_by_laprop.get_fn(): scale_by_laprop, #
+    update_by_adopt.get_fn(): scale_by_adopt, #
+}


 class BaseOpt(ChainOpt):
@@ -622,8 +899,18 @@ class BaseOpt(ChainOpt):
     palm: bool = False
     auto_fuse: bool = True

-    def __init__(
-
+    def __init__(
+        self,
+        params,
+        defaults,
+        foreach: bool,
+        gradient_clipping: str_or_fn,
+        update_clipping: str_or_fn,
+        palm: bool = use_default,
+        *fns,
+        compile_step: bool = use_default,
+        promote: bool = use_default,
+    ):
         if not fns:
             raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")

@@ -643,8 +930,10 @@ class BaseOpt(ChainOpt):
                 fns = tuple(fns)[:-1] + (fn,)
         elif fn in _scale_to_update_map_inv:
             if not self.auto_fuse:
-                raise ValueError(
-
+                raise ValueError(
+                    "update_clipping is currently not compatible with update_by_* functions. "
+                    "Manually select scale_by_* functions or set auto_fuse=True."
+                )
             fn = _scale_to_update_map_inv[fn]
             if args is not None:
                 fn = functools.partial(fn, *args, **kwargs)
@@ -665,27 +954,27 @@ class BaseOpt(ChainOpt):
 class ScheduleFree(BaseOpt):
     def eval(self):
         for group in self.param_groups:
-            group[
+            group["train_mode"] = train_mode = not group.get("train_mode")
             beta1 = utils.get_beta1(group)
             if beta1 > 0 and not train_mode:
-                for p in group[
+                for p in group["params"]:
                     state = self.state_(p)
-                    if
+                    if "z" in state:
                         # Set p.data to x
-                        z = utils.promote(state[
+                        z = utils.promote(state["z"])
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - 1 / beta1)
                         utils.copy_stochastic_(p.data, p32)

     def train(self):
         for group in self.param_groups:
-            group[
+            group["train_mode"] = train_mode = not group.get("train_mode")
             beta1 = utils.get_beta1(group)
             if beta1 > 0 and train_mode:
-                for p in group[
+                for p in group["params"]:
                     state = self.state_(p)
-                    if
-                        z = utils.promote(state[
+                    if "z" in state:
+                        z = utils.promote(state["z"])
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - beta1)
                         utils.copy_stochastic_(p.data, p32)
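The `ScheduleFree.eval()` / `train()` bodies above interpolate between the parameter and its `z` buffer with lerp weights `1 - 1 / beta1` and `1 - beta1`. A short numeric check (illustrative only, plain tensors rather than optimizer state) that these two weights invert each other when the stored parameter is assumed to hold `y = beta1 * x + (1 - beta1) * z`:

import torch

beta1 = 0.9
x = torch.randn(5)                 # evaluation-time weights
z = torch.randn(5)                 # schedule-free z buffer
y = beta1 * x + (1 - beta1) * z    # assumed train-mode parameter

p = y.clone()
p.lerp_(end=z, weight=1 - 1 / beta1)   # eval(): recover x from y and z
assert torch.allclose(p, x, atol=1e-6)

p.lerp_(end=z, weight=1 - beta1)       # train(): map x back to y
assert torch.allclose(p, y, atol=1e-6)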