heavyball 1.6.3-py3-none-any.whl → 1.7.1-py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
heavyball/chainable.py CHANGED
@@ -1,8 +1,10 @@
1
1
  import functools
2
+ import math
2
3
  import random
3
- from typing import Optional, Union, Literal, List
4
+ from typing import List, Literal, Optional, Union
4
5
 
5
6
  import torch
7
+ from torch import Tensor
6
8
 
7
9
  from . import utils
8
10
 
@@ -42,7 +44,7 @@ class FunctionTransform:
42
44
  raise NotImplementedError
43
45
 
44
46
  def get_fn(self):
45
- if hasattr(self.fn, 'get_fn'):
47
+ if utils.hasattr_none(self.fn, "get_fn"):
46
48
  return self.fn.get_fn()
47
49
  return self.fn
48
50
 
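For context on the change above: `utils.hasattr_none` is new in this release and lives in heavyball/utils.py, so its body is not shown in this diff. Judging from how it replaces plain `hasattr` here and in the PSGD code further down (where the attribute is read immediately afterwards), it presumably treats an attribute that exists but is set to None as absent; a minimal sketch under that assumption:

    def hasattr_none(obj, name: str) -> bool:
        # Hypothetical sketch, not the packaged implementation: an attribute that is
        # missing or explicitly set to None counts as "not there".
        return getattr(obj, name, None) is not None
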
@@ -55,7 +57,7 @@ def _zero_guard(state, key, ref, dtype):
55
57
 
56
58
 
57
59
  def _storage_dtype(group):
58
- dtype = group.get('storage_dtype', "float32")
60
+ dtype = group.get("storage_dtype", "float32")
59
61
  return getattr(torch, dtype)
60
62
 
61
63
 
@@ -65,8 +67,10 @@ class ZeroGuard(FunctionTransform):
65
67
  self.names = names
66
68
 
67
69
  def __call__(self, state, group, update, grad, param, *args, **kwargs):
68
- vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
69
- for name in self.names]
70
+ vars = [
71
+ [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
72
+ for name in self.names
73
+ ]
70
74
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
71
75
 
72
76
 
@@ -78,8 +82,10 @@ class CopyGuard(FunctionTransform):
78
82
 
79
83
  def __call__(self, state, group, update, grad, param, *args, **kwargs):
80
84
  val = [update, grad, param, *args][self.index]
81
- vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
82
- for name in self.names]
85
+ vars = [
86
+ [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
87
+ for name in self.names
88
+ ]
83
89
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
84
90
 
85
91
 
@@ -152,145 +158,243 @@ def exp_avg(group, update, grad, param, exp_avg):
152
158
  return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
153
159
 
154
160
 
155
- @zero_guard('exp_avg')
161
+ @zero_guard("exp_avg")
156
162
  @no_state
157
163
  def weight_decay_to_ema(group, update, grad, param, exp_avg):
158
- utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
159
- group['weight_decay_to_ema'] * group['lr'])
164
+ utils.weight_decay_to_ema_(
165
+ exp_avg,
166
+ update,
167
+ utils.beta_debias(group["ema_beta"], group["step"]),
168
+ group["weight_decay_to_ema"] * group["lr"],
169
+ )
160
170
  return update
161
171
 
162
172
 
163
- @zero_guard('exp_avg')
173
+ @zero_guard("exp_avg")
164
174
  @no_state
165
175
  def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
166
- utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
167
- group['weight_decay_to_ema'] * group['lr'])
176
+ utils.l1_weight_decay_to_ema_(
177
+ exp_avg,
178
+ update,
179
+ utils.beta_debias(group["ema_beta"], group["step"]),
180
+ group["weight_decay_to_ema"] * group["lr"],
181
+ )
168
182
  return update
169
183
 
170
184
 
171
185
  @zero_guard("exp_avg_sq")
172
186
  @no_state
173
187
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
174
- return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
175
- group['eps'])
188
+ return utils.scale_by_exp_avg_sq_(
189
+ exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
190
+ )
176
191
 
177
192
 
178
193
  @zero_guard("exp_avg", "exp_avg_sq")
179
194
  @no_state
180
195
  def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
181
- return utils.adam_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'], #
182
- group['eps'])
196
+ return utils.adam_(
197
+ exp_avg,
198
+ exp_avg_sq,
199
+ update,
200
+ utils.get_beta1(group),
201
+ utils.get_beta2(group),
202
+ group["step"], #
203
+ group["eps"],
204
+ )
183
205
 
184
206
 
185
207
  @zero_guard("exp_avg", "exp_avg_sq")
186
208
  @no_state
187
209
  def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
188
- utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
189
- group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
210
+ utils.fused_adam_(
211
+ param,
212
+ exp_avg,
213
+ exp_avg_sq,
214
+ update,
215
+ grad,
216
+ utils.get_beta1(group),
217
+ utils.get_beta2(group),
218
+ group["step"],
219
+ group["lr"],
220
+ group["eps"],
221
+ group["weight_decay"],
222
+ group["caution"],
223
+ )
190
224
  raise SkipUpdate
191
225
 
192
226
 
193
227
  @zero_guard("exp_avg", "exp_avg_sq")
194
228
  @no_state
195
229
  def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
196
- return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])
230
+ return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group["step"])
197
231
 
198
232
 
199
233
  @zero_guard("exp_avg", "exp_avg_sq")
200
234
  @no_state
201
235
  def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
202
- utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
203
- group['step'], group['lr'], group['weight_decay'], group['caution'])
236
+ utils.fused_laprop_(
237
+ param,
238
+ exp_avg,
239
+ exp_avg_sq,
240
+ update,
241
+ grad,
242
+ utils.get_beta1(group),
243
+ utils.get_beta2(group),
244
+ group["step"],
245
+ group["lr"],
246
+ group["weight_decay"],
247
+ group["caution"],
248
+ )
204
249
  raise SkipUpdate
205
250
 
206
251
 
207
252
  @no_state
208
253
  def orthogonalize_grad_to_param(group, update, grad, param):
209
- return utils.orthogonalize_grad_to_param(param, update, group['eps'])
254
+ return utils.orthogonalize_grad_to_param(param, update, group["eps"])
210
255
 
211
256
 
212
257
  @copy_guard(2, "z")
213
258
  @no_state
214
259
  def update_by_schedule_free(group, update, grad, param, z):
215
- group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0),
216
- utils.get_beta1(group), param, z, update, grad, group['caution'],
217
- group['r'], group['step'], group['weight_decay'])
260
+ group["weight_sum"] = utils.schedule_free_(
261
+ group["lr"],
262
+ group["weight_lr_power"],
263
+ group.get("weight_sum", 0),
264
+ utils.get_beta1(group),
265
+ param,
266
+ z,
267
+ update,
268
+ grad,
269
+ group["caution"],
270
+ group["r"],
271
+ group["step"],
272
+ group["weight_decay"],
273
+ )
218
274
  raise SkipUpdate
219
275
 
220
276
 
221
277
  @zero_guard("exp_avg", "exp_avg_sq")
222
278
  @no_state
223
279
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
224
- if group['step'] == 1:
225
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
280
+ if group["step"] == 1:
281
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
226
282
  raise SkipUpdate
227
283
 
228
- if group['step'] == 2:
284
+ if group["step"] == 2:
229
285
  update = utils.promote(update)
230
286
  easq = utils.promote(exp_avg_sq)
231
- [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
232
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
233
- group['eps'])
287
+ [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
288
+ utils.scale_by_exp_avg_sq_(
289
+ exp_avg_sq,
290
+ update,
291
+ utils.beta_debias(utils.get_beta2(group), group["step"]),
292
+ group["eps"],
293
+ )
234
294
  raise SkipUpdate
235
295
 
236
- utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
237
- group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
296
+ utils.fused_adopt_(
297
+ param,
298
+ update,
299
+ grad,
300
+ exp_avg_sq,
301
+ exp_avg,
302
+ utils.get_beta1(group),
303
+ utils.get_beta2(group),
304
+ group["step"] - 2,
305
+ group["lr"],
306
+ group["eps"],
307
+ group["weight_decay"],
308
+ group["caution"],
309
+ )
238
310
  raise SkipUpdate
239
311
 
240
312
 
241
313
  @zero_guard("exp_avg", "exp_avg_sq")
242
314
  @no_state
243
315
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
244
- if group['step'] == 1:
245
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
316
+ if group["step"] == 1:
317
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
246
318
  raise SkipUpdate
247
319
 
248
- if group['step'] == 2:
320
+ if group["step"] == 2:
249
321
  update = utils.promote(update)
250
322
  easq = utils.promote(exp_avg_sq)
251
- [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
252
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
253
- group['eps'])
323
+ [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)]
324
+ utils.scale_by_exp_avg_sq_(
325
+ exp_avg_sq,
326
+ update,
327
+ utils.beta_debias(utils.get_beta2(group), group["step"]),
328
+ group["eps"],
329
+ )
254
330
  raise SkipUpdate
255
331
 
256
- return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
257
-
258
-
259
- def _init_soap(state, group, update, grad, param, inner: str = ''):
260
- utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])
261
-
262
-
263
- def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
264
- Q, state["exprs"] = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'],
265
- group['min_ndim_triangular'], group['memory_save_mode'],
266
- dtype=getattr(torch, group['q_dtype']))
267
- state["Q"] = utils.triu_to_line(Q) if group['store_triu_as_line'] else Q
332
+ return utils.adopt(
333
+ update,
334
+ exp_avg_sq,
335
+ exp_avg,
336
+ utils.get_beta1(group),
337
+ utils.get_beta2(group),
338
+ group["step"] - 2,
339
+ )
340
+
341
+
342
+ def _init_soap(state, group, update, grad, param, inner: str = ""):
343
+ utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"])
344
+
345
+
346
+ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
347
+ Q, state["exprs"] = utils.init_Q_exprs(
348
+ grad,
349
+ group["precond_init_scale"],
350
+ group["precond_init_scale_scale"],
351
+ group["max_size_triangular"],
352
+ group["min_ndim_triangular"],
353
+ group["memory_save_mode"],
354
+ getattr(param, "hessian_vector", None),
355
+ getattr(param, "vector", None),
356
+ dtype=getattr(torch, group["q_dtype"]),
357
+ )
358
+ state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
268
359
 
269
360
  if not cached:
270
361
  return
271
362
 
272
- state['Q_cache'] = [torch.empty_like(q) for q in Q]
363
+ state["Q_cache"] = [torch.empty_like(q) for q in Q]
273
364
 
274
- expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
275
- expr = ','.join(expr)
276
- grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape))
277
- out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
278
- expr = f'{expr},{grad_expr}->{out_expr}'
365
+ expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
366
+ expr = ",".join(expr)
367
+ grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
368
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
369
+ expr = f"{expr},{grad_expr}->{out_expr}"
279
370
 
280
- state['cache_expr'] = expr
371
+ state["cache_expr"] = expr
281
372
 
282
373
 
283
- def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'):
284
- step = group['step']
285
- if 'precondition_frequency' in group:
286
- return step > 0 and step % group['precondition_frequency'] == 0
374
+ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
375
+ state["U"], state["V"], state["d"] = utils.init_lra(
376
+ grad,
377
+ group["precond_init_scale"],
378
+ group["precond_init_scale_scale"],
379
+ group["rank"],
380
+ getattr(param, "hessian_vector", None),
381
+ getattr(param, "vector", None),
382
+ dtype=getattr(torch, group["q_dtype"]),
383
+ )
384
+ group["preconditioning_step"] = 0
385
+
386
+
387
+ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
388
+ step = group["step"]
389
+ if "precondition_frequency" in group:
390
+ return step > 0 and step % group["precondition_frequency"] == 0
287
391
  if isinstance(step, torch.Tensor):
288
392
  utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
289
393
  rng = random.Random(0x172381)
290
394
  else:
291
395
  rng = random.Random(0x172381 ^ step)
292
- if 'precond_scheduler' in group:
293
- return utils.precond_schedule(step, group['precond_scheduler'], rng)
396
+ if "precond_scheduler" in group:
397
+ return utils.precond_schedule(step, group["precond_scheduler"], rng)
294
398
  if prob is not None:
295
399
  return utils.psgd_should_update(group, prob, rng, name=name)
296
400
  raise ValueError("No preconditioner update schedule specified.")
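The reworked `precond_schedule` above keeps three mutually exclusive ways of deciding whether to refresh the preconditioner on a given step: a fixed `precondition_frequency`, a `precond_scheduler`, or a probability (callable or float). A tiny illustration of the frequency branch, with made-up values:

    group = {"step": 20, "precondition_frequency": 10}
    refresh = group["step"] > 0 and group["step"] % group["precondition_frequency"] == 0
    # True at steps 10, 20, 30, ...; False on all other steps.
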
@@ -313,17 +417,17 @@ def nesterov_momentum(group, updates, grads, params, momentum):
313
417
  return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
314
418
 
315
419
 
316
- @zero_guard('momentum')
420
+ @zero_guard("momentum")
317
421
  @no_state
318
422
  def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grokfast
319
423
  return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
320
424
 
321
425
 
322
426
  def _store_std(state, group, update, grad, param):
323
- state['init_std'] = torch.std(grad, dim=0)
427
+ state["init_std"] = torch.std(grad, dim=0)
324
428
 
325
429
 
326
- @general_guard("init_std", init_fn=_store_std)
430
+ @general_guard("init_std", init_fn=_store_std, skip_first=False)
327
431
  @no_state
328
432
  def mup_approx(group, updates, grads, params, init_std):
329
433
  _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
@@ -332,31 +436,79 @@ def mup_approx(group, updates, grads, params, init_std):
332
436
  return updates
333
437
 
334
438
 
439
+ def _init_delta(state, group, update, grad, param, log_space: bool):
440
+ val = group["initial_d"]
441
+ state["delta"] = torch.full((), math.log(val) if log_space else val, dtype=param.dtype, device=param.device)
442
+
443
+
444
+ def _init_full_delta(state, group, update, grad, param, log_space: bool):
445
+ val = group["initial_d"]
446
+ state["delta"] = torch.full_like(param, math.log(val) if log_space else val)
447
+
448
+
449
+ @zero_guard("state")
450
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
451
+ @no_state
452
+ def scale_by_d_adaptation(group, update, grad, param, state, delta):
453
+ utils.d_adaptation(grad, update, state, delta)
454
+ return update
455
+
456
+
457
+ @zero_guard("state")
458
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
459
+ @no_state
460
+ def scale_by_lr_adaptation(group, update, grad, param, state, delta):
461
+ utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
462
+ return update
463
+
464
+
465
+ @zero_guard("state")
466
+ @general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
467
+ @no_state
468
+ def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
469
+ utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
470
+ return update
471
+
472
+
335
473
  @zero_guard("momentum")
336
474
  @no_state
337
475
  def heavyball_momentum(group, updates, grads, params, momentum):
338
476
  return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
339
477
 
340
478
 
341
- _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
479
+ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}
342
480
 
343
481
 
344
482
  @zero_guard("exp_avg", "exp_avg_sq")
345
483
  @general_guard("Q", "GG", init_fn=_init_soap)
346
484
  @no_state
347
- def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
485
+ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
348
486
  update = utils.promote(update) # Promote to highest precision if needed
349
487
 
350
488
  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
351
489
  fn = _optim_fns[inner]
352
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
353
- group['eps'])
490
+ precond = fn(
491
+ exp_avg,
492
+ exp_avg_sq,
493
+ grad_projected,
494
+ utils.get_beta1(group),
495
+ utils.get_beta2(group),
496
+ group["step"] - 1,
497
+ group["eps"],
498
+ )
354
499
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
355
500
 
356
501
  for u, q, gg, ea in zip(update, Q, GG, exp_avg):
357
- utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
358
- utils.beta_debias(group['shampoo_beta'], group['step']),
359
- group['is_preconditioning'])
502
+ utils.update_preconditioner(
503
+ u,
504
+ q,
505
+ gg,
506
+ ea,
507
+ group["max_precond_dim"],
508
+ group["precondition_1d"],
509
+ utils.beta_debias(group["shampoo_beta"], group["step"]),
510
+ group["is_preconditioning"],
511
+ )
360
512
  return precond
361
513
 
362
514
 
@@ -364,17 +516,28 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
364
516
  if prob is None:
365
517
  prob = utils.precond_update_prob_schedule()
366
518
 
367
- if not group['is_preconditioning']:
519
+ if not group["is_preconditioning"]:
368
520
  return Q_mat
369
521
 
370
- utils.psgd_update_precond(Q_mat, exprs, getattr(param, 'hessian_vector', grad), group['precond_lr'], Q,
371
- group['store_triu_as_line'], getattr(param, 'vector', None))
372
- if hasattr(param, 'vector'):
522
+ if utils.hasattr_none(param, "vector"):
523
+ vector, hessian_vector = param.vector, param.hessian_vector
373
524
  del param.vector
374
525
  del param.hessian_vector
526
+ else:
527
+ vector, hessian_vector = utils.dampen_grad(grad)
528
+
529
+ utils.psgd_update_precond(
530
+ Q_mat,
531
+ exprs,
532
+ hessian_vector,
533
+ group["precond_lr"],
534
+ Q,
535
+ group["store_triu_as_line"],
536
+ vector,
537
+ )
375
538
 
376
539
  if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
377
- if group['store_triu_as_line']:
540
+ if group["store_triu_as_line"]:
378
541
  utils.psgd_balance_Q([q_ for _, q_ in Q])
379
542
  else:
380
543
  utils.psgd_balance_Q(Q)
@@ -382,8 +545,8 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
382
545
  if isinstance(prob, float):
383
546
  float_prob = prob
384
547
  else:
385
- float_prob = prob(group.get(f'cumulative_prob_{id(Q)}_prob_step', 1))
386
- group['is_cached'] = should_use_cache = cached and float_prob < 0.5
548
+ float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
549
+ group["is_cached"] = should_use_cache = cached and float_prob < 0.5
387
550
 
388
551
  if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
389
552
  return _update_psgd_cache(cached, Q_cache, Q_mat)
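When no Hessian-vector product has been attached to the parameter, the rewritten `_update_psgd_precond` above now falls back to `utils.dampen_grad(grad)` to fabricate a (vector, hessian_vector) pair. That helper is defined in heavyball/utils.py rather than in this file; in PSGD-style gradient whitening it is typically a random probe paired with a slightly damped gradient, roughly as follows (the damping default below is illustrative):

    import torch

    def dampen_grad(grad: torch.Tensor, damp: float = 2**-13):
        # Hypothetical sketch of the whitening fallback, assuming the conventional
        # PSGD formulation: pair a Gaussian probe with grad + damp * probe.
        v = torch.randn_like(grad)
        return v, grad + damp * v
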
@@ -403,51 +566,172 @@ def _update_psgd_cache(cached, Q_cache, q):
403
566
 
404
567
 
405
568
  def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
406
- if group.get('is_cached', False):
407
- out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
408
- out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
409
- group['caution'] = False # we already cautioned here - shouldn't do it again
569
+ if group.get("is_cached", False):
570
+ out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
571
+ else:
572
+ out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
573
+ group["caution"] = False # we already cautioned here - shouldn't do it again
410
574
  return out
411
575
 
412
576
 
413
577
  def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
414
- if group.get('is_cached', False):
415
- utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
416
- group['caution'], *Q_cache)
578
+ if group.get("is_cached", False):
579
+ utils.fused_precond_grad_cached_(
580
+ cache_expr,
581
+ update,
582
+ param,
583
+ group["lr"],
584
+ grad,
585
+ group["weight_decay"],
586
+ group["caution"],
587
+ *Q_cache,
588
+ )
417
589
  else:
418
- utils.fused_psgd_precond_grad(exprs[-1], update, param, group['lr'], grad, group['weight_decay'],
419
- group['caution'], *Q_mat)
590
+ utils.fused_psgd_precond_grad(
591
+ exprs[-1],
592
+ update,
593
+ param,
594
+ group["lr"],
595
+ grad,
596
+ group["weight_decay"],
597
+ group["caution"],
598
+ *Q_mat,
599
+ )
600
+
601
+
602
+ def _update_lra(
603
+ group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
604
+ ):
605
+ if not group["is_preconditioning"]:
606
+ return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
607
+
608
+ if utils.hasattr_none(params[0], "hessian_vector"):
609
+ vector = utils.flatten([p.vector for p in params])
610
+ hessian_vector = utils.flatten([p.hessian_vector for p in params])
611
+ for p in params:
612
+ del p.vector
613
+ del p.hessian_vector
614
+ else:
615
+ vector, hessian_vector = utils.dampen_multiple(grads)
616
+ return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
617
+
618
+
619
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
620
+ @no_state
621
+ def scale_by_psgd_lra(group, update, grad, param, U, V, d):
622
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
623
+ return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
624
+
625
+
626
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
627
+ @no_state
628
+ def update_by_psgd_lra(group, update, grad, param, U, V, d):
629
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
630
+ utils.apply_lra_update(param, update, u, v, d)
631
+ raise SkipUpdate
632
+
633
+
634
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
635
+ @no_state
636
+ def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
637
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
638
+ return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
639
+
640
+
641
+ @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
642
+ @no_state
643
+ def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
644
+ u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
645
+ utils.apply_lra_update(param, update, u, v, d)
646
+ raise SkipUpdate
420
647
 
421
648
 
422
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
649
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
423
650
  @no_state_no_foreach
424
- def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
425
- prob: Optional[callable] = None):
651
+ def scale_by_psgd(
652
+ group,
653
+ update,
654
+ grad,
655
+ param,
656
+ Q,
657
+ exprs,
658
+ Q_cache,
659
+ cache_expr: str,
660
+ cached: bool = False,
661
+ prob: Optional[callable] = None,
662
+ ):
426
663
  update = update.to(memory_format=torch.contiguous_format)
427
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
428
- Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
429
- update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
664
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
665
+ Q_mat = _update_psgd_precond(
666
+ cached,
667
+ Q_cache,
668
+ group,
669
+ param,
670
+ update if group["momentum_into_precond_update"] else grad,
671
+ Q_mat,
672
+ Q,
673
+ exprs,
674
+ prob,
675
+ )
430
676
  return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
431
677
 
432
678
 
433
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
679
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
434
680
  @no_state_no_foreach
435
- def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
436
- prob: Optional[callable] = None):
437
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
681
+ def scale_by_delayed_psgd(
682
+ group,
683
+ update,
684
+ grad,
685
+ param,
686
+ Q,
687
+ exprs,
688
+ Q_cache,
689
+ cache_expr: str,
690
+ cached: bool = False,
691
+ prob: Optional[callable] = None,
692
+ ):
693
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
438
694
  precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
439
- _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
440
- Q_mat, Q, exprs, prob)
695
+ _ = _update_psgd_precond(
696
+ cached,
697
+ Q_cache,
698
+ group,
699
+ param,
700
+ update if group["momentum_into_precond_update"] else grad,
701
+ Q_mat,
702
+ Q,
703
+ exprs,
704
+ prob,
705
+ )
441
706
  return precond
442
707
 
443
708
 
444
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
709
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
445
710
  @no_state_no_foreach
446
- def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
447
- prob: Optional[callable] = None):
448
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
449
- Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
450
- update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
711
+ def update_by_psgd(
712
+ group,
713
+ update,
714
+ grad,
715
+ param,
716
+ Q,
717
+ exprs,
718
+ Q_cache,
719
+ cache_expr: str,
720
+ cached: bool = False,
721
+ prob: Optional[callable] = None,
722
+ ):
723
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
724
+ Q_mat = _update_psgd_precond(
725
+ cached,
726
+ Q_cache,
727
+ group,
728
+ param,
729
+ update if group["momentum_into_precond_update"] else grad,
730
+ Q_mat,
731
+ Q,
732
+ exprs,
733
+ prob,
734
+ )
451
735
  _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
452
736
  raise SkipUpdate
453
737
 
@@ -457,20 +741,39 @@ def sign(group, update, grad, param, graft: bool = True):
457
741
  return utils.sign_(update, graft)
458
742
 
459
743
 
460
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
744
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
461
745
  @no_state_no_foreach
462
- def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
463
- prob: Optional[callable] = None):
464
- Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
746
+ def update_by_delayed_psgd(
747
+ group,
748
+ update,
749
+ grad,
750
+ param,
751
+ Q,
752
+ exprs,
753
+ Q_cache,
754
+ cache_expr: str,
755
+ cached: bool = False,
756
+ prob: Optional[callable] = None,
757
+ ):
758
+ Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
465
759
  _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
466
- _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
467
- Q_mat, Q, exprs, prob)
760
+ _ = _update_psgd_precond(
761
+ cached,
762
+ Q_cache,
763
+ group,
764
+ param,
765
+ update if group["momentum_into_precond_update"] else grad,
766
+ Q_mat,
767
+ Q,
768
+ exprs,
769
+ prob,
770
+ )
468
771
  raise SkipUpdate
469
772
 
470
773
 
471
774
  def palm_beta2(state, group, update, grad, param):
472
- beta2 = 1 - group['step'] ** -group['beta2_scale']
473
- group['betas'] = (utils.get_beta1(group), beta2)
775
+ beta2 = 1 - group["step"] ** -group["beta2_scale"]
776
+ group["betas"] = (utils.get_beta1(group), beta2)
474
777
  return update
475
778
 
476
779
 
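`palm_beta2` above implements the PaLM-style schedule beta2 = 1 - step ** (-beta2_scale). Assuming a beta2_scale of 0.8 (an illustrative value, not read from this diff), the horizon of the second-moment average grows over training:

    for step in (10, 100, 1000, 10000):
        beta2 = 1 - step ** -0.8
        # ~0.842, ~0.975, ~0.996, ~0.9994 (short horizons early, long horizons late).
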
@@ -499,7 +802,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
499
802
  update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
500
803
  update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
501
804
  if not skip_update and update is not None:
502
- utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
805
+ utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
503
806
 
504
807
 
505
808
  def create_branch(branches: List[List[callable]], merge_fn: callable):
@@ -524,14 +827,16 @@ class ChainOpt(utils.StatefulOptimizer):
524
827
  self.fns = tuple(fns)
525
828
 
526
829
  def _step(self, group):
527
- if 'base_lr' not in group:
528
- group['base_lr'] = group['lr']
529
- if 'prev_lr' in group and group['prev_lr'] != group['lr']:
530
- utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
531
- f'only supported with foreach=True (currently foreach={group["foreach"]}).')
532
- group['base_lr'] = group['lr']
830
+ if "base_lr" not in group:
831
+ group["base_lr"] = group["lr"]
832
+ if "prev_lr" in group and group["prev_lr"] != group["lr"]:
833
+ utils.warn_once(
834
+ f"Learning rate changed between steps. This is an experimental feature and "
835
+ f"only supported with foreach=True (currently foreach={group['foreach']})."
836
+ )
837
+ group["base_lr"] = group["lr"]
533
838
 
534
- caution = group['caution']
839
+ caution = group["caution"]
535
840
 
536
841
  vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
537
842
 
@@ -541,26 +846,26 @@ class ChainOpt(utils.StatefulOptimizer):
541
846
 
542
847
  for param in p:
543
848
  state = self.state_(param)
544
- if 'step' in state:
545
- step = state['step']
849
+ if "step" in state:
850
+ step = state["step"]
546
851
  elif self.compile_step:
547
852
  step = utils.scalar_guard(0, param)
548
853
  else:
549
854
  step = 0
550
855
  break
551
856
 
552
- group['step'] = state['step'] = step = step + 1
553
- group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1)
857
+ group["step"] = state["step"] = step = step + 1
858
+ group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
554
859
 
555
- if not group['foreach'] or len(p) == 1:
860
+ if not group["foreach"] or len(p) == 1:
556
861
  for param, grad in zip(p, g):
557
862
  chain(self.state_, group, [grad], [param], *self.fns)
558
863
  else:
559
864
  chain(self.state_, group, g, p, *self.fns)
560
865
 
561
- group['caution'] = caution
562
- group['lr'] = group['prev_lr']
563
- group['step'] = None
866
+ group["caution"] = caution
867
+ group["lr"] = group["prev_lr"]
868
+ group["step"] = None
564
869
 
565
870
 
566
871
  use_default = object()
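The step bookkeeping in the hunk above also derives the effective learning rate from `base_lr` with a linear warmup: lr = base_lr * step / max(step, warmup_steps + 1). A quick numerical check with illustrative values:

    base_lr, warmup_steps = 1e-3, 100
    for step in (1, 50, 101, 500):
        lr = base_lr * step / max(step, warmup_steps + 1)
        # ~9.9e-6, ~5.0e-4, 1e-3, 1e-3 (linear ramp, then constant at base_lr).
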
@@ -571,7 +876,13 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
571
876
  name = default(name, default_val)
572
877
  if callable(name):
573
878
  return name
574
- elif name not in ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress'):
879
+ elif name not in (
880
+ "l2_clip_",
881
+ "rmsnorm_clip_",
882
+ "trust_region_clip_",
883
+ "a_law_compress",
884
+ "mu_law_compress",
885
+ ):
575
886
  raise ValueError(f"Clipping function {name} not found")
576
887
  return getattr(utils, name)
577
888
 
@@ -581,16 +892,24 @@ def default(a, b):
581
892
 
582
893
 
583
894
  # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
584
- _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
585
- scale_by_psgd.get_fn(): update_by_psgd, #
586
- scale_by_adam.get_fn(): update_by_adam, #
587
- scale_by_laprop.get_fn(): update_by_laprop, #
588
- scale_by_adopt.get_fn(): update_by_adopt}
589
- _scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
590
- update_by_psgd.get_fn(): scale_by_psgd, #
591
- update_by_adam.get_fn(): scale_by_adam, #
592
- update_by_laprop.get_fn(): scale_by_laprop, #
593
- update_by_adopt.get_fn(): scale_by_adopt}
895
+ _scale_to_update_map = {
896
+ scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
897
+ scale_by_psgd.get_fn(): update_by_psgd, #
898
+ scale_by_psgd_lra.get_fn(): update_by_psgd_lra, #
899
+ scale_by_delayed_psgd_lra.get_fn(): update_by_delayed_psgd_lra, #
900
+ scale_by_adam.get_fn(): update_by_adam, #
901
+ scale_by_laprop.get_fn(): update_by_laprop, #
902
+ scale_by_adopt.get_fn(): update_by_adopt, #
903
+ }
904
+ _scale_to_update_map_inv = {
905
+ update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
906
+ update_by_psgd.get_fn(): scale_by_psgd, #
907
+ update_by_psgd_lra.get_fn(): scale_by_psgd_lra, #
908
+ update_by_delayed_psgd_lra.get_fn(): scale_by_delayed_psgd_lra, #
909
+ update_by_adam.get_fn(): scale_by_adam, #
910
+ update_by_laprop.get_fn(): scale_by_laprop, #
911
+ update_by_adopt.get_fn(): scale_by_adopt, #
912
+ }
594
913
 
595
914
 
596
915
  class BaseOpt(ChainOpt):
@@ -622,8 +941,18 @@ class BaseOpt(ChainOpt):
622
941
  palm: bool = False
623
942
  auto_fuse: bool = True
624
943
 
625
- def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
626
- palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
944
+ def __init__(
945
+ self,
946
+ params,
947
+ defaults,
948
+ foreach: bool,
949
+ gradient_clipping: str_or_fn,
950
+ update_clipping: str_or_fn,
951
+ palm: bool = use_default,
952
+ *fns,
953
+ compile_step: bool = use_default,
954
+ promote: bool = use_default,
955
+ ):
627
956
  if not fns:
628
957
  raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
629
958
 
@@ -643,8 +972,10 @@ class BaseOpt(ChainOpt):
643
972
  fns = tuple(fns)[:-1] + (fn,)
644
973
  elif fn in _scale_to_update_map_inv:
645
974
  if not self.auto_fuse:
646
- raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
647
- "Manually select scale_by_* functions or set auto_fuse=True.")
975
+ raise ValueError(
976
+ "update_clipping is currently not compatible with update_by_* functions. "
977
+ "Manually select scale_by_* functions or set auto_fuse=True."
978
+ )
648
979
  fn = _scale_to_update_map_inv[fn]
649
980
  if args is not None:
650
981
  fn = functools.partial(fn, *args, **kwargs)
@@ -665,27 +996,27 @@ class BaseOpt(ChainOpt):
665
996
  class ScheduleFree(BaseOpt):
666
997
  def eval(self):
667
998
  for group in self.param_groups:
668
- group['train_mode'] = train_mode = not group.get('train_mode')
999
+ group["train_mode"] = train_mode = not group.get("train_mode")
669
1000
  beta1 = utils.get_beta1(group)
670
1001
  if beta1 > 0 and not train_mode:
671
- for p in group['params']:
1002
+ for p in group["params"]:
672
1003
  state = self.state_(p)
673
- if 'z' in state:
1004
+ if "z" in state:
674
1005
  # Set p.data to x
675
- z = utils.promote(state['z'])
1006
+ z = utils.promote(state["z"])
676
1007
  p32 = utils.promote(p.data)
677
1008
  p32.lerp_(end=z, weight=1 - 1 / beta1)
678
1009
  utils.copy_stochastic_(p.data, p32)
679
1010
 
680
1011
  def train(self):
681
1012
  for group in self.param_groups:
682
- group['train_mode'] = train_mode = not group.get('train_mode')
1013
+ group["train_mode"] = train_mode = not group.get("train_mode")
683
1014
  beta1 = utils.get_beta1(group)
684
1015
  if beta1 > 0 and train_mode:
685
- for p in group['params']:
1016
+ for p in group["params"]:
686
1017
  state = self.state_(p)
687
- if 'z' in state:
688
- z = utils.promote(state['z'])
1018
+ if "z" in state:
1019
+ z = utils.promote(state["z"])
689
1020
  p32 = utils.promote(p.data)
690
1021
  p32.lerp_(end=z, weight=1 - beta1)
691
1022
  utils.copy_stochastic_(p.data, p32)
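The `eval()`/`train()` pair above switches between the two parameter copies used by schedule-free training: in train mode `p.data` holds the interpolation y = beta1 * x + (1 - beta1) * z, `eval()` lerps toward z with weight 1 - 1/beta1 to recover x, and `train()` lerps back with weight 1 - beta1 (this reading follows the inline "Set p.data to x" comment rather than anything stated elsewhere in the package). A small numerical check:

    import torch

    beta1 = 0.9
    x, z = torch.tensor(2.0), torch.tensor(-1.0)
    y = beta1 * x + (1 - beta1) * z   # what p.data holds in train mode
    x_rec = y.lerp(z, 1 - 1 / beta1)  # eval(): recovers x (2.0)
    y_rec = x_rec.lerp(z, 1 - beta1)  # train(): restores y (1.7)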