heavyball 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -1,8 +1,9 @@
  import functools
  import random
- from typing import Optional, Union, Literal, List
+ from typing import List, Literal, Optional, Union

  import torch
+ from torch import Tensor

  from . import utils

@@ -42,7 +43,7 @@ class FunctionTransform:
  raise NotImplementedError

  def get_fn(self):
- if hasattr(self.fn, 'get_fn'):
+ if hasattr(self.fn, "get_fn"):
  return self.fn.get_fn()
  return self.fn

@@ -55,7 +56,7 @@ def _zero_guard(state, key, ref, dtype):


  def _storage_dtype(group):
- dtype = group.get('storage_dtype', "float32")
+ dtype = group.get("storage_dtype", "float32")
  return getattr(torch, dtype)


@@ -65,8 +66,10 @@ class ZeroGuard(FunctionTransform):
  self.names = names

  def __call__(self, state, group, update, grad, param, *args, **kwargs):
- vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
- for name in self.names]
+ vars = [
+ [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
+ for name in self.names
+ ]
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


@@ -78,8 +81,10 @@ class CopyGuard(FunctionTransform):

  def __call__(self, state, group, update, grad, param, *args, **kwargs):
  val = [update, grad, param, *args][self.index]
- vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
- for name in self.names]
+ vars = [
+ [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
+ for name in self.names
+ ]
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


@@ -152,145 +157,243 @@ def exp_avg(group, update, grad, param, exp_avg):
  return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


- @zero_guard('exp_avg')
+ @zero_guard("exp_avg")
  @no_state
  def weight_decay_to_ema(group, update, grad, param, exp_avg):
- utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
- group['weight_decay_to_ema'] * group['lr'])
+ utils.weight_decay_to_ema_(
+ exp_avg,
+ update,
+ utils.beta_debias(group["ema_beta"], group["step"]),
+ group["weight_decay_to_ema"] * group["lr"],
+ )
  return update


- @zero_guard('exp_avg')
+ @zero_guard("exp_avg")
  @no_state
  def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
- utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
- group['weight_decay_to_ema'] * group['lr'])
+ utils.l1_weight_decay_to_ema_(
+ exp_avg,
+ update,
+ utils.beta_debias(group["ema_beta"], group["step"]),
+ group["weight_decay_to_ema"] * group["lr"],
+ )
  return update


  @zero_guard("exp_avg_sq")
  @no_state
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
- return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
- group['eps'])
+ return utils.scale_by_exp_avg_sq_(
+ exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+ )


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
- return utils.adam_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'], #
- group['eps'])
+ return utils.adam_(
+ exp_avg,
+ exp_avg_sq,
+ update,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"], #
+ group["eps"],
+ )


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
- group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
+ utils.fused_adam_(
+ param,
+ exp_avg,
+ exp_avg_sq,
+ update,
+ grad,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"],
+ group["lr"],
+ group["eps"],
+ group["weight_decay"],
+ group["caution"],
+ )
  raise SkipUpdate


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+ return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group["step"])


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
- group['step'], group['lr'], group['weight_decay'], group['caution'])
+ utils.fused_laprop_(
+ param,
+ exp_avg,
+ exp_avg_sq,
+ update,
+ grad,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"],
+ group["lr"],
+ group["weight_decay"],
+ group["caution"],
+ )
  raise SkipUpdate


  @no_state
  def orthogonalize_grad_to_param(group, update, grad, param):
- return utils.orthogonalize_grad_to_param(param, update, group['eps'])
+ return utils.orthogonalize_grad_to_param(param, update, group["eps"])


  @copy_guard(2, "z")
  @no_state
  def update_by_schedule_free(group, update, grad, param, z):
- group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
- utils.get_beta1(group), param, z, update, grad, group['caution'],
- group['r'], group['step'], group['weight_decay'])
+ group["weight_sum"] = utils.schedule_free_(
+ group["lr"],
+ group["weight_lr_power"],
+ group.get("weight_sum", 0),
+ utils.get_beta1(group),
+ param,
+ z,
+ update,
+ grad,
+ group["caution"],
+ group["r"],
+ group["step"],
+ group["weight_decay"],
+ )
  raise SkipUpdate


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group['step'] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
+ if group["step"] == 1:
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
  raise SkipUpdate

- if group['step'] == 2:
+ if group["step"] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
- [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
- group['eps'])
+ [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+ utils.scale_by_exp_avg_sq_(
+ exp_avg_sq,
+ update,
+ utils.beta_debias(utils.get_beta2(group), group["step"]),
+ group["eps"],
+ )
  raise SkipUpdate

- utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
- group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
+ utils.fused_adopt_(
+ param,
+ update,
+ grad,
+ exp_avg_sq,
+ exp_avg,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"] - 2,
+ group["lr"],
+ group["eps"],
+ group["weight_decay"],
+ group["caution"],
+ )
  raise SkipUpdate


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group['step'] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
+ if group["step"] == 1:
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
  raise SkipUpdate

- if group['step'] == 2:
+ if group["step"] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
- [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
- group['eps'])
+ [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
+ utils.scale_by_exp_avg_sq_(
+ exp_avg_sq,
+ update,
+ utils.beta_debias(utils.get_beta2(group), group["step"]),
+ group["eps"],
+ )
  raise SkipUpdate

- return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
-
-
- def _init_soap(state, group, update, grad, param, inner: str = ''):
- utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])
-
-
- def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
- Q, state["exprs"] = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'],
- group['min_ndim_triangular'], group['memory_save_mode'],
- dtype=getattr(torch, group['q_dtype']))
- state["Q"] = utils.triu_to_line(Q) if group['store_triu_as_line'] else Q
+ return utils.adopt(
+ update,
+ exp_avg_sq,
+ exp_avg,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"] - 2,
+ )
+
+
+ def _init_soap(state, group, update, grad, param, inner: str = ""):
+ utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"])
+
+
+ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ Q, state["exprs"] = utils.init_Q_exprs(
+ grad,
+ group["precond_init_scale"],
+ group["precond_init_scale_scale"],
+ group["max_size_triangular"],
+ group["min_ndim_triangular"],
+ group["memory_save_mode"],
+ getattr(param, "hessian_vector", None),
+ getattr(param, "vector", None),
+ dtype=getattr(torch, group["q_dtype"]),
+ )
+ state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q

  if not cached:
  return

- state['Q_cache'] = [torch.empty_like(q) for q in Q]
+ state["Q_cache"] = [torch.empty_like(q) for q in Q]
+
+ expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
+ expr = ",".join(expr)
+ grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
+ expr = f"{expr},{grad_expr}->{out_expr}"

- expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
- expr = ','.join(expr)
- grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
- out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
- expr = f'{expr},{grad_expr}->{out_expr}'
+ state["cache_expr"] = expr

- state['cache_expr'] = expr

+ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ state["U"], state["V"], state["d"] = utils.init_lra(
+ grad,
+ group["precond_init_scale"],
+ group["precond_init_scale_scale"],
+ group["rank"],
+ getattr(param, "hessian_vector", None),
+ getattr(param, "vector", None),
+ dtype=getattr(torch, group["q_dtype"]),
+ )
+ group["preconditioning_step"] = 0

- def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'):
- step = group['step']
- if 'precondition_frequency' in group:
- return step > 0 and step % group['precondition_frequency'] == 0
+
+ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
+ step = group["step"]
+ if "precondition_frequency" in group:
+ return step > 0 and step % group["precondition_frequency"] == 0
  if isinstance(step, torch.Tensor):
  utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
  rng = random.Random(0x172381)
  else:
  rng = random.Random(0x172381 ^ step)
- if 'precond_scheduler' in group:
- return utils.precond_schedule(step, group['precond_scheduler'], rng)
+ if "precond_scheduler" in group:
+ return utils.precond_schedule(step, group["precond_scheduler"], rng)
  if prob is not None:
  return utils.psgd_should_update(group, prob, rng, name=name)
  raise ValueError("No preconditioner update schedule specified.")
@@ -313,14 +416,14 @@ def nesterov_momentum(group, updates, grads, params, momentum):
  return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


- @zero_guard('momentum')
+ @zero_guard("momentum")
  @no_state
  def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grokfast
  return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))


  def _store_std(state, group, update, grad, param):
- state['init_std'] = torch.std(grad, dim=0)
+ state["init_std"] = torch.std(grad, dim=0)


  @general_guard("init_std", init_fn=_store_std)
@@ -338,25 +441,39 @@ def heavyball_momentum(group, updates, grads, params, momentum):
  return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))


- _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
+ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}


  @zero_guard("exp_avg", "exp_avg_sq")
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
- def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
+ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
  update = utils.promote(update) # Promote to highest precision if needed

  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
  fn = _optim_fns[inner]
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
- group['eps'])
+ precond = fn(
+ exp_avg,
+ exp_avg_sq,
+ grad_projected,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"] - 1,
+ group["eps"],
+ )
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

  for u, q, gg, ea in zip(update, Q, GG, exp_avg):
- utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
- utils.beta_debias(group['shampoo_beta'], group['step']),
- group['is_preconditioning'])
+ utils.update_preconditioner(
+ u,
+ q,
+ gg,
+ ea,
+ group["max_precond_dim"],
+ group["precondition_1d"],
+ utils.beta_debias(group["shampoo_beta"], group["step"]),
+ group["is_preconditioning"],
+ )
  return precond


@@ -364,17 +481,24 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
  if prob is None:
  prob = utils.precond_update_prob_schedule()

- if not group['is_preconditioning']:
+ if not group["is_preconditioning"]:
  return Q_mat

- utils.psgd_update_precond(Q_mat, exprs, getattr(param, 'hessian_vector', grad), group['precond_lr'], Q,
- group['store_triu_as_line'], getattr(param, 'vector', None))
- if hasattr(param, 'vector'):
+ utils.psgd_update_precond(
+ Q_mat,
+ exprs,
+ getattr(param, "hessian_vector", grad),
+ group["precond_lr"],
+ Q,
+ group["store_triu_as_line"],
+ getattr(param, "vector", None),
+ )
+ if hasattr(param, "vector"):
  del param.vector
  del param.hessian_vector

  if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
- if group['store_triu_as_line']:
+ if group["store_triu_as_line"]:
  utils.psgd_balance_Q([q_ for _, q_ in Q])
  else:
  utils.psgd_balance_Q(Q)
@@ -382,8 +506,8 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
  if isinstance(prob, float):
  float_prob = prob
  else:
- float_prob = prob(group.get(f'cumulative_prob_{id(Q)}_prob_step', 1))
- group['is_cached'] = should_use_cache = cached and float_prob < 0.5
+ float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
+ group["is_cached"] = should_use_cache = cached and float_prob < 0.5

  if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
  return _update_psgd_cache(cached, Q_cache, Q_mat)
@@ -403,51 +527,169 @@ def _update_psgd_cache(cached, Q_cache, q):


  def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
- if group.get('is_cached', False):
- out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
- out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
- group['caution'] = False # we already cautioned here - shouldn't do it again
+ if group.get("is_cached", False):
+ out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
+ else:
+ out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
+ group["caution"] = False # we already cautioned here - shouldn't do it again
  return out


  def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
- if group.get('is_cached', False):
- utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
- group['caution'], *Q_cache)
+ if group.get("is_cached", False):
+ utils.fused_precond_grad_cached_(
+ cache_expr,
+ update,
+ param,
+ group["lr"],
+ grad,
+ group["weight_decay"],
+ group["caution"],
+ *Q_cache,
+ )
  else:
- utils.fused_psgd_precond_grad(exprs[-1], update, param, group['lr'], grad, group['weight_decay'],
- group['caution'], *Q_mat)
+ utils.fused_psgd_precond_grad(
+ exprs[-1],
+ update,
+ param,
+ group["lr"],
+ grad,
+ group["weight_decay"],
+ group["caution"],
+ *Q_mat,
+ )
+
+
+ def _update_lra(
+ group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
+ ):
+ if not group["is_preconditioning"]:
+ return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
+
+ if hasattr(params[0], "hessian_vector") and params[0].hessian_vector is not None:
+ vector = utils.flatten([p.vector for p in params])
+ hessian_vector = utils.flatten([p.hessian_vector for p in params])
+ else:
+ vector, hessian_vector = utils.dampen_multiple(grads)
+ return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
+
+
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+ @no_state
+ def scale_by_psgd_lra(group, update, grad, param, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+ return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+ @no_state
+ def update_by_psgd_lra(group, update, grad, param, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+ utils.apply_lra_update(param, update, u, v, d)
+ raise SkipUpdate
+
+
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+ @no_state
+ def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+ return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
+
+
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
+ @no_state
+ def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+ utils.apply_lra_update(param, update, u, v, d)
+ raise SkipUpdate


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
- def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
- prob: Optional[callable] = None):
+ def scale_by_psgd(
+ group,
+ update,
+ grad,
+ param,
+ Q,
+ exprs,
+ Q_cache,
+ cache_expr: str,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+ ):
  update = update.to(memory_format=torch.contiguous_format)
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
- update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+ Q_mat = _update_psgd_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ update if group["momentum_into_precond_update"] else grad,
+ Q_mat,
+ Q,
+ exprs,
+ prob,
+ )
  return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
- def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
- prob: Optional[callable] = None):
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+ def scale_by_delayed_psgd(
+ group,
+ update,
+ grad,
+ param,
+ Q,
+ exprs,
+ Q_cache,
+ cache_expr: str,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+ ):
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
  precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
- _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
- Q_mat, Q, exprs, prob)
+ _ = _update_psgd_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ update if group["momentum_into_precond_update"] else grad,
+ Q_mat,
+ Q,
+ exprs,
+ prob,
+ )
  return precond


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
- def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
- prob: Optional[callable] = None):
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
- update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
+ def update_by_psgd(
+ group,
+ update,
+ grad,
+ param,
+ Q,
+ exprs,
+ Q_cache,
+ cache_expr: str,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+ ):
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
+ Q_mat = _update_psgd_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ update if group["momentum_into_precond_update"] else grad,
+ Q_mat,
+ Q,
+ exprs,
+ prob,
+ )
  _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
  raise SkipUpdate

@@ -457,20 +699,39 @@ def sign(group, update, grad, param, graft: bool = True):
  return utils.sign_(update, graft)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
- def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
- prob: Optional[callable] = None):
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
+ def update_by_delayed_psgd(
+ group,
+ update,
+ grad,
+ param,
+ Q,
+ exprs,
+ Q_cache,
+ cache_expr: str,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+ ):
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
  _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
- _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
- Q_mat, Q, exprs, prob)
+ _ = _update_psgd_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ update if group["momentum_into_precond_update"] else grad,
+ Q_mat,
+ Q,
+ exprs,
+ prob,
+ )
  raise SkipUpdate


  def palm_beta2(state, group, update, grad, param):
- beta2 = 1 - group['step'] ** -group['beta2_scale']
- group['betas'] = (utils.get_beta1(group), beta2)
+ beta2 = 1 - group["step"] ** -group["beta2_scale"]
+ group["betas"] = (utils.get_beta1(group), beta2)
  return update


@@ -499,7 +760,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
  update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
  update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
  if not skip_update and update is not None:
- utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
+ utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)


  def create_branch(branches: List[List[callable]], merge_fn: callable):
@@ -524,14 +785,16 @@ class ChainOpt(utils.StatefulOptimizer):
  self.fns = tuple(fns)

  def _step(self, group):
- if 'base_lr' not in group:
- group['base_lr'] = group['lr']
- if 'prev_lr' in group and group['prev_lr'] != group['lr']:
- utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
- f'only supported with foreach=True (currently foreach={group["foreach"]}).')
- group['base_lr'] = group['lr']
+ if "base_lr" not in group:
+ group["base_lr"] = group["lr"]
+ if "prev_lr" in group and group["prev_lr"] != group["lr"]:
+ utils.warn_once(
+ f"Learning rate changed between steps. This is an experimental feature and "
+ f"only supported with foreach=True (currently foreach={group['foreach']})."
+ )
+ group["base_lr"] = group["lr"]

- caution = group['caution']
+ caution = group["caution"]

  vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))

@@ -541,26 +804,26 @@ class ChainOpt(utils.StatefulOptimizer):

  for param in p:
  state = self.state_(param)
- if 'step' in state:
- step = state['step']
+ if "step" in state:
+ step = state["step"]
  elif self.compile_step:
  step = utils.scalar_guard(0, param)
  else:
  step = 0
  break

- group['step'] = state['step'] = step = step + 1
- group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1)
+ group["step"] = state["step"] = step = step + 1
+ group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)

- if not group['foreach'] or len(p) == 1:
+ if not group["foreach"] or len(p) == 1:
  for param, grad in zip(p, g):
  chain(self.state_, group, [grad], [param], *self.fns)
  else:
  chain(self.state_, group, g, p, *self.fns)

- group['caution'] = caution
- group['lr'] = group['prev_lr']
- group['step'] = None
+ group["caution"] = caution
+ group["lr"] = group["prev_lr"]
+ group["step"] = None


  use_default = object()
@@ -571,7 +834,13 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
  name = default(name, default_val)
  if callable(name):
  return name
- elif name not in ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress'):
+ elif name not in (
+ "l2_clip_",
+ "rmsnorm_clip_",
+ "trust_region_clip_",
+ "a_law_compress",
+ "mu_law_compress",
+ ):
  raise ValueError(f"Clipping function {name} not found")
  return getattr(utils, name)

@@ -581,16 +850,24 @@ def default(a, b):


  # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
- _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
- scale_by_psgd.get_fn(): update_by_psgd, #
- scale_by_adam.get_fn(): update_by_adam, #
- scale_by_laprop.get_fn(): update_by_laprop, #
- scale_by_adopt.get_fn(): update_by_adopt}
- _scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
- update_by_psgd.get_fn(): scale_by_psgd, #
- update_by_adam.get_fn(): scale_by_adam, #
- update_by_laprop.get_fn(): scale_by_laprop, #
- update_by_adopt.get_fn(): scale_by_adopt}
+ _scale_to_update_map = {
+ scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+ scale_by_psgd.get_fn(): update_by_psgd, #
+ scale_by_psgd_lra.get_fn(): update_by_psgd_lra, #
+ scale_by_delayed_psgd_lra.get_fn(): update_by_delayed_psgd_lra, #
+ scale_by_adam.get_fn(): update_by_adam, #
+ scale_by_laprop.get_fn(): update_by_laprop, #
+ scale_by_adopt.get_fn(): update_by_adopt, #
+ }
+ _scale_to_update_map_inv = {
+ update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
+ update_by_psgd.get_fn(): scale_by_psgd, #
+ update_by_psgd_lra.get_fn(): scale_by_psgd_lra, #
+ update_by_delayed_psgd_lra.get_fn(): scale_by_delayed_psgd_lra, #
+ update_by_adam.get_fn(): scale_by_adam, #
+ update_by_laprop.get_fn(): scale_by_laprop, #
+ update_by_adopt.get_fn(): scale_by_adopt, #
+ }


  class BaseOpt(ChainOpt):
@@ -622,8 +899,18 @@ class BaseOpt(ChainOpt):
  palm: bool = False
  auto_fuse: bool = True

- def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+ def __init__(
+ self,
+ params,
+ defaults,
+ foreach: bool,
+ gradient_clipping: str_or_fn,
+ update_clipping: str_or_fn,
+ palm: bool = use_default,
+ *fns,
+ compile_step: bool = use_default,
+ promote: bool = use_default,
+ ):
  if not fns:
  raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")

@@ -643,8 +930,10 @@ class BaseOpt(ChainOpt):
  fns = tuple(fns)[:-1] + (fn,)
  elif fn in _scale_to_update_map_inv:
  if not self.auto_fuse:
- raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
- "Manually select scale_by_* functions or set auto_fuse=True.")
+ raise ValueError(
+ "update_clipping is currently not compatible with update_by_* functions. "
+ "Manually select scale_by_* functions or set auto_fuse=True."
+ )
  fn = _scale_to_update_map_inv[fn]
  if args is not None:
  fn = functools.partial(fn, *args, **kwargs)
@@ -665,27 +954,27 @@ class BaseOpt(ChainOpt):
  class ScheduleFree(BaseOpt):
  def eval(self):
  for group in self.param_groups:
- group['train_mode'] = train_mode = not group.get('train_mode')
+ group["train_mode"] = train_mode = not group.get("train_mode")
  beta1 = utils.get_beta1(group)
  if beta1 > 0 and not train_mode:
- for p in group['params']:
+ for p in group["params"]:
  state = self.state_(p)
- if 'z' in state:
+ if "z" in state:
  # Set p.data to x
- z = utils.promote(state['z'])
+ z = utils.promote(state["z"])
  p32 = utils.promote(p.data)
  p32.lerp_(end=z, weight=1 - 1 / beta1)
  utils.copy_stochastic_(p.data, p32)

  def train(self):
  for group in self.param_groups:
- group['train_mode'] = train_mode = not group.get('train_mode')
+ group["train_mode"] = train_mode = not group.get("train_mode")
  beta1 = utils.get_beta1(group)
  if beta1 > 0 and train_mode:
- for p in group['params']:
+ for p in group["params"]:
  state = self.state_(p)
- if 'z' in state:
- z = utils.promote(state['z'])
+ if "z" in state:
+ z = utils.promote(state["z"])
  p32 = utils.promote(p.data)
  p32.lerp_(end=z, weight=1 - beta1)
  utils.copy_stochastic_(p.data, p32)
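
The main functional addition in 1.7.0 above is the low-rank PSGD path: _init_psgd_lra allocates U, V, and d state via utils.init_lra, _update_lra refreshes them through utils.update_lra_precond_, and the new scale_by_psgd_lra / update_by_psgd_lra transforms apply them to the flattened update with utils.lra_precond. As a rough, hypothetical illustration of what such a low-rank-plus-diagonal preconditioner computes (assuming the common PSGD "UVd" parameterization Q = (I + U Vᵀ) diag(d); the exact contraction inside utils.lra_precond is not shown in this diff), a minimal pure-PyTorch sketch:

import torch

# Hypothetical sketch only, assuming Q = (I + U @ V.T) @ diag(d).
# This is NOT heavyball's utils.lra_precond; it only illustrates the shapes
# of the U, V, d state that _init_psgd_lra allocates above.
n, rank = 1024, 4
U = torch.randn(n, rank) * 0.01
V = torch.randn(n, rank) * 0.01
d = torch.ones(n)

def lra_precondition(g: torch.Tensor) -> torch.Tensor:
    # Apply P @ g with P = Q.T @ Q, without materializing an n x n matrix.
    x = d * g                # diag(d) @ g
    x = x + U @ (V.T @ x)    # Q @ g = (I + U V^T) diag(d) g
    x = x + V @ (U.T @ x)    # (I + V U^T) @ (Q g)
    return d * x             # Q^T Q g

flat_grad = torch.randn(n)
preconditioned = lra_precondition(flat_grad)
assert preconditioned.shape == flat_grad.shape

In the diff itself, scale_by_psgd_lra applies the preconditioner to the flattened update of the whole parameter group and utils.extract_from_flat_update splits the result back into per-parameter tensors.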