heavyball 1.7.2__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -1,14 +1,15 @@
+ import copy
  import functools
  import math
  import random
- from typing import List, Literal, Optional, Union
+ from typing import Iterable, List, Literal, Optional, Union
 
  import torch
  from torch import Tensor
 
  from . import utils
 
- balance_probability: float = 0.01
+ use_default = utils.use_default
 
 
  def _key_in_state(state, key):
@@ -36,20 +37,52 @@ def _guard_in_state(state, key, template_fn):
 
 
  class FunctionTransform:
- def __init__(self, fn):
+ def __init__(self, fn, names: list[str] | None = None):
+ if names is None:
+ names = []
  self.fn = fn
  self.fn_name = self.get_fn().__name__
+ self.transform_idx = None
+ self.is_initialized = False
+ self.names = names
 
- def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
  raise NotImplementedError
 
+ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+ raise NotImplementedError
+
+ def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ states = [state(p) for p in param]
+ skip_update = False
+ for st, a in zip(states, zip(update, grad, param, *args)):
+ if self.transform_idx not in st.get("is_initialized", set()):
+ try:
+ self._init(st, group, *a, **kwargs)
+ except SkipUpdate:
+ skip_update = True
+ except:
+ raise
+ finally:
+ if "is_initialized" not in st:
+ st["is_initialized"] = set()
+ st["is_initialized"].add(self.transform_idx)
+ if skip_update:
+ raise SkipUpdate from None
+ vars = [[st.get(self.val_name(name), None) for st in states] for name in self.names]
+ return self._call(state, group, update, grad, param, vars, *args, **kwargs)
+
  def get_fn(self):
  if utils.hasattr_none(self.fn, "get_fn"):
  return self.fn.get_fn()
  return self.fn
 
  def val_name(self, name):
- return f"{self.fn_name}_{name}"
+ assert self.transform_idx is not None
+ return f"{self.fn_name}_{name}_{self.transform_idx}"
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
 
 
  def _zero_guard(state, key, ref, dtype):
@@ -63,49 +96,102 @@ def _storage_dtype(group):
 
  class ZeroGuard(FunctionTransform):
  def __init__(self, fn, names):
- super().__init__(fn)
- self.names = names
+ super().__init__(fn, names)
 
- def __call__(self, state, group, update, grad, param, *args, **kwargs):
- vars = [
- [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
- for name in self.names
- ]
+ def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+ for name in self.names:
+ _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
+ class PrecondGradAccumGuard(FunctionTransform):
+ def __init__(self, fn):
+ super().__init__(fn, ["precond_grad_accum"])
+ self.steps_taken = 0
+ self.pass_through = None
+
+ def _accum(self, state, new):
+ self.steps_taken += 1
+ utils.stochastic_add_(state, new)
+
+ def _reset(self, state):
+ if self.steps_taken == 0:
+ self.steps_taken = 0
+ utils.zero_(state)
+
+ def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+ if self.pass_through is None:
+ self.pass_through = not group.get("precond_grad_accum", False)
+ if self.pass_through is False:
+ for name in self.names:
+ _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+ base_grad = update if group.get("momentum_into_precond_update", True) else grad
+ if self.pass_through:
+ return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
+
+ (vars,) = vars
+ if group["is_preconditioning"]:
+ if self.steps_taken:
+ self._accum(vars, base_grad)
+ utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+ else:
+ vars = base_grad
+ else:
+ self._accum(vars, base_grad)
+ vars = base_grad
+ try:
+ out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
+ finally:
+ if group["is_preconditioning"]:
+ self._reset(vars)
+
+ return out
+
+
  class CopyGuard(FunctionTransform):
  def __init__(self, fn, index, names):
- super().__init__(fn)
+ super().__init__(fn, names)
  self.index = index
- self.names = names
 
- def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
  val = [update, grad, param, *args][self.index]
- vars = [
- [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
- for name in self.names
- ]
+ for name in self.names:
+ state[self.val_name(name)] = torch.clone(val)
+
+ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
  return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
- class GeneralGuard(FunctionTransform): # We can't guard against reuse in the general case
+ class GeneralGuard(FunctionTransform):
  def __init__(self, fn, names, init_fn, skip_first: bool = True):
- super().__init__(fn)
- self.names = names
+ super().__init__(fn, names)
  self.init_fn = init_fn
  self.skip_first = skip_first
-
- def __call__(self, state, group, update, grad, param, *args, **kwargs):
- vars = []
- skip_update = False
- for p, g, u in zip(param, grad, update):
- st = state(p)
- skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
- vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
- if skip_update and self.skip_first:
- raise SkipUpdate
- return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
+ self.named_to_anonymous = None
+ self.anonymous_to_named = None
+
+ def _map(self, state_fn, param, mapping):
+ for p in param:
+ state = state_fn(p)
+ for name, mapped in mapping.items():
+ if mapped in state:
+ raise ValueError(f"Name {name} already mapped to {mapped}")
+ if name in state:
+ state[mapped] = state.pop(name)
+
+ def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+ self.init_fn(state, group, update, grad, param, **kwargs)
+ for name in self.names:
+ state[self.val_name(name)] = state.pop(name, None)
+ if self.skip_first:
+ raise SkipUpdate from None
+
+ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+ return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
 
  class NoState(FunctionTransform):
@@ -124,7 +210,7 @@ class NoStateNoForeach(FunctionTransform):
  skip_update = True
  pass
  if skip_update:
- raise SkipUpdate
+ raise SkipUpdate from None
  return updates
 
 
@@ -158,12 +244,23 @@ def exp_avg(group, update, grad, param, exp_avg):
  return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+ @copy_guard(2, "init")
+ @no_state
+ def weight_decay_to_init(group, update, grad, param, init):
+ utils.weight_decay_to_init_(
+ param,
+ init,
+ group["weight_decay_to_ema"] * group["lr"],
+ )
+ return update
+
+
  @zero_guard("exp_avg")
  @no_state
  def weight_decay_to_ema(group, update, grad, param, exp_avg):
  utils.weight_decay_to_ema_(
+ param,
  exp_avg,
- update,
  utils.beta_debias(group["ema_beta"], group["step"]),
  group["weight_decay_to_ema"] * group["lr"],
  )
@@ -174,8 +271,8 @@ def weight_decay_to_ema(group, update, grad, param, exp_avg):
  @no_state
  def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
  utils.l1_weight_decay_to_ema_(
+ param,
  exp_avg,
- update,
  utils.beta_debias(group["ema_beta"], group["step"]),
  group["weight_decay_to_ema"] * group["lr"],
  )
@@ -221,7 +318,7 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
  group["weight_decay"],
  group["caution"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
 
 
  @zero_guard("exp_avg", "exp_avg_sq")
@@ -246,7 +343,7 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
  group["weight_decay"],
  group["caution"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
 
 
  @no_state
@@ -271,7 +368,26 @@ def update_by_schedule_free(group, update, grad, param, z):
  group["step"],
  group["weight_decay"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
+
+
+ @copy_guard(2, "z")
+ @zero_guard("exp_avg")
+ @no_state
+ def update_by_msam(group, update, grad, param, z, exp_avg):
+ utils.msam_(
+ group["lr"],
+ utils.beta_debias(utils.get_beta1(group), group["step"]),
+ param,
+ z,
+ update,
+ grad,
+ exp_avg,
+ group["caution"],
+ group["weight_decay"],
+ group["sam_step_size"],
+ )
+ raise SkipUpdate from None
 
 
  @zero_guard("exp_avg", "exp_avg_sq")
@@ -279,7 +395,7 @@ def update_by_schedule_free(group, update, grad, param, z):
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group["step"] == 1:
  utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate
+ raise SkipUpdate from None
 
  if group["step"] == 2:
  update = utils.promote(update)
@@ -291,7 +407,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  utils.beta_debias(utils.get_beta2(group), group["step"]),
  group["eps"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
 
  utils.fused_adopt_(
  param,
@@ -307,7 +423,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  group["weight_decay"],
  group["caution"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
 
 
  @zero_guard("exp_avg", "exp_avg_sq")
@@ -315,7 +431,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group["step"] == 1:
  utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate
+ raise SkipUpdate from None
 
  if group["step"] == 2:
  update = utils.promote(update)
@@ -327,7 +443,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  utils.beta_debias(utils.get_beta2(group), group["step"]),
  group["eps"],
  )
- raise SkipUpdate
+ raise SkipUpdate from None
 
  return utils.adopt(
  update,
@@ -344,10 +460,11 @@ def _init_soap(state, group, update, grad, param, inner: str = ""):
 
 
  def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
- Q, state["exprs"] = utils.init_Q_exprs(
+ Q = utils.init_Q_exprs(
  grad,
  group["precond_init_scale"],
  group["precond_init_scale_scale"],
+ group["precond_init_scale_power"],
  group["max_size_triangular"],
  group["min_ndim_triangular"],
  group["memory_save_mode"],
@@ -356,32 +473,27 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
  dtype=getattr(torch, group["q_dtype"]),
  )
  state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-
+ state["running_lower_bound"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+ if group["adaptive"]:
+ state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
  if not cached:
  return
 
  state["Q_cache"] = [torch.empty_like(q) for q in Q]
 
- expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
- expr = ",".join(expr)
- grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
- out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
- expr = f"{expr},{grad_expr}->{out_expr}"
-
- state["cache_expr"] = expr
-
 
  def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
  state["U"], state["V"], state["d"] = utils.init_lra(
  grad,
+ group["param_count"],
  group["precond_init_scale"],
  group["precond_init_scale_scale"],
+ group["precond_init_scale_power"],
  group["rank"],
  getattr(param, "hessian_vector", None),
  getattr(param, "vector", None),
  dtype=getattr(torch, group["q_dtype"]),
  )
- group["preconditioning_step"] = 0
 
 
  def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
@@ -402,12 +514,12 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
 
  @no_state_no_foreach
  def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
- if update.dim() == 1:
+ if update.dim() < 2:
  return update
  original_shape = update.shape
  # doing it this way, as tmp and update are not guaranteed to share memory address or layout
  tmp = update.flatten(1, -1)
- utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
+ utils.inplace_orthogonal_(tmp, out=tmp, scale_mode=scale_mode)
  return tmp.reshape(original_shape)
 
 
@@ -424,7 +536,7 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
 
 
  def _store_std(state, group, update, grad, param):
- state["init_std"] = torch.std(grad, dim=0)
+ state["init_std"] = torch.std(param)
 
 
  @general_guard("init_std", init_fn=_store_std, skip_first=False)
@@ -483,9 +595,7 @@ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
- update = utils.promote(update) # Promote to highest precision if needed
-
- grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
+ grad_projected = [utils.project(utils.promote(u), q, False) for u, q in zip(update, Q)]
  fn = _optim_fns[inner]
  precond = fn(
  exp_avg,
@@ -500,7 +610,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
  for u, q, gg, ea in zip(update, Q, GG, exp_avg):
  utils.update_preconditioner(
- u,
+ utils.promote(u),
  q,
  gg,
  ea,
@@ -512,35 +622,38 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
  return precond
 
 
- def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+ def _update_psgd_precond(
+ cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, prob: Optional[callable] = None
+ ) -> Optional[Tensor]:
  if prob is None:
  prob = utils.precond_update_prob_schedule()
 
  if not group["is_preconditioning"]:
- return Q_mat
+ return
 
  if utils.hasattr_none(param, "vector"):
  vector, hessian_vector = param.vector, param.hessian_vector
  del param.vector
  del param.hessian_vector
+ elif group["inverse_free"]:
+ vector, hessian_vector = None, grad
  else:
- vector, hessian_vector = utils.dampen_grad(grad)
+ vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
 
- utils.psgd_update_precond(
- Q_mat,
- exprs,
+ precond = (utils.inverse_free_psgd_update_precond if vector is None else utils.psgd_update_precond)(
  hessian_vector,
  group["precond_lr"],
  Q,
  group["store_triu_as_line"],
+ velocity,
+ utils.beta_debias(utils.get_beta2(group), group["step"]),
+ group["ortho_method"],
  vector,
+ running_lower_bound,
+ group["lower_bound_beta"],
+ group["precond_update_power_iterations"],
  )
-
- if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
- if group["store_triu_as_line"]:
- utils.psgd_balance_Q([q_ for _, q_ in Q])
- else:
- utils.psgd_balance_Q(Q)
+ del vector, hessian_vector
 
  if isinstance(prob, float):
  float_prob = prob
@@ -548,54 +661,45 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
  float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
  group["is_cached"] = should_use_cache = cached and float_prob < 0.5
 
- if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
- return _update_psgd_cache(cached, Q_cache, Q_mat)
- return Q_mat
-
+ if precond is not None:
+ return precond
+ if not should_use_cache or not cached:
+ return None # caching adds extra ops and is not worth the overhead when we precondition at every step
 
- def _update_psgd_cache(cached, Q_cache, q):
- if not cached:
- return q
-
- for c_, q_ in zip(Q_cache, q):
+ for c_, q_ in zip(Q_cache, utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q):
  if q_.ndim == 2:
  torch.matmul(q_.T, q_, out=c_)
  else:
  torch.mul(q_, q_, out=c_)
- return Q_cache
+ return None
 
 
- def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
+ def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
+ kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
  if group.get("is_cached", False):
- out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
+ out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
  else:
- out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
+ out = utils.psgd_precond_grad(
+ preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+ )
  group["caution"] = False # we already cautioned here - shouldn't do it again
  return out
 
 
- def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
+ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
+ kwargs = {
+ "ea": update,
+ "caution": group["caution"],
+ "grad": grad,
+ "param": param,
+ "lr": group["lr"],
+ "decay": group["weight_decay"],
+ }
  if group.get("is_cached", False):
- utils.fused_precond_grad_cached_(
- cache_expr,
- update,
- param,
- group["lr"],
- grad,
- group["weight_decay"],
- group["caution"],
- *Q_cache,
- )
+ utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
  else:
  utils.fused_psgd_precond_grad(
- exprs[-1],
- update,
- param,
- group["lr"],
- grad,
- group["weight_decay"],
- group["caution"],
- *Q_mat,
+ preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
  )
 
 
@@ -603,7 +707,7 @@ def _update_lra(
  group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
  ):
  if not group["is_preconditioning"]:
- return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
+ return utils.multi_flatten((U, 1), (V, 1), (d, 0))
 
  if utils.hasattr_none(params[0], "hessian_vector"):
  vector = utils.flatten([p.vector for p in params])
@@ -613,127 +717,109 @@ def _update_lra(
  del p.hessian_vector
  else:
  vector, hessian_vector = utils.dampen_multiple(grads)
- return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
+ precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
+ return utils.update_lra_precond_(
+ U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed, bool(precond_step % 2)
+ )
 
 
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def scale_by_psgd_lra(group, update, grad, param, U, V, d):
- u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+ def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
  return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def update_by_psgd_lra(group, update, grad, param, U, V, d):
- u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
- utils.apply_lra_update(param, update, u, v, d)
- raise SkipUpdate
+ def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
+ utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+ raise SkipUpdate from None
 
 
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
- u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+ def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
  return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
- u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
- utils.apply_lra_update(param, update, u, v, d)
- raise SkipUpdate
+ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+ u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
+ utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+ raise SkipUpdate from None
 
 
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def scale_by_psgd(
  group,
  update,
  grad,
  param,
+ update_to_precond,
  Q,
- exprs,
  Q_cache,
- cache_expr: str,
+ velocity: Optional[List[Tensor]],
+ running_lower_bound: List[Tensor],
  cached: bool = False,
  prob: Optional[callable] = None,
  ):
- update = update.to(memory_format=torch.contiguous_format)
- Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
- Q_mat = _update_psgd_precond(
- cached,
- Q_cache,
- group,
- param,
- update if group["momentum_into_precond_update"] else grad,
- Q_mat,
- Q,
- exprs,
- prob,
- )
- return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+ return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
 
 
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def scale_by_delayed_psgd(
  group,
  update,
  grad,
  param,
+ update_to_precond,
  Q,
- exprs,
  Q_cache,
- cache_expr: str,
+ velocity: Optional[List[Tensor]],
+ running_lower_bound: List[Tensor],
  cached: bool = False,
  prob: Optional[callable] = None,
  ):
- Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
- precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
- _ = _update_psgd_precond(
- cached,
- Q_cache,
- group,
- param,
- update if group["momentum_into_precond_update"] else grad,
- Q_mat,
- Q,
- exprs,
- prob,
- )
- return precond
+ if group.get("inverse_free", False):
+ precond = None
+ else:
+ precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+ new = _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+ return new if precond is None else precond
 
 
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def update_by_psgd(
  group,
  update,
  grad,
  param,
+ update_to_precond,
  Q,
- exprs,
  Q_cache,
- cache_expr: str,
+ velocity: Optional[List[Tensor]],
+ running_lower_bound: List[Tensor],
  cached: bool = False,
  prob: Optional[callable] = None,
  ):
- Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
- Q_mat = _update_psgd_precond(
- cached,
- Q_cache,
- group,
- param,
- update if group["momentum_into_precond_update"] else grad,
- Q_mat,
- Q,
- exprs,
- prob,
- )
- _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
- raise SkipUpdate
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+ _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+ raise SkipUpdate from None
 
 
  @no_state
@@ -741,34 +827,31 @@ def sign(group, update, grad, param, graft: bool = True):
  return utils.sign_(update, graft)
 
 
- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @no_state
+ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
+ assert clip_fn is not None
+ return clip_fn(update)
+
+
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def update_by_delayed_psgd(
  group,
  update,
  grad,
  param,
+ update_to_precond,
  Q,
- exprs,
  Q_cache,
- cache_expr: str,
+ velocity: Optional[List[Tensor]],
+ running_lower_bound: List[Tensor],
  cached: bool = False,
  prob: Optional[callable] = None,
  ):
- Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
- _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
- _ = _update_psgd_precond(
- cached,
- Q_cache,
- group,
- param,
- update if group["momentum_into_precond_update"] else grad,
- Q_mat,
- Q,
- exprs,
- prob,
- )
- raise SkipUpdate
+ _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+ raise SkipUpdate from None
 
 
  def palm_beta2(state, group, update, grad, param):
@@ -819,12 +902,54 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
  return _branch
 
 
+ def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+ if retain:
+ if offset:
+ raise ValueError("offset cannot be retained")
+
+ offset = -1
+ for fn in fns:
+ while isinstance(fn, (FunctionTransform, functools.partial)):
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ continue
+ if fn.transform_idx is not None:
+ offset = max(offset, fn.transform_idx)
+ fn = fn.fn
+ offset += 1 # if we found nothing, this will be 0. if we found something, we START at N+1
+
+ fns = [copy.deepcopy(fn) for fn in fns]
+ for fn in fns:
+ while isinstance(fn, (FunctionTransform, functools.partial)):
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ continue
+ if not retain or fn.transform_idx is None:
+ fn.transform_idx = offset
+ offset += 1
+ fn = fn.fn
+ return fns
+
+
  class ChainOpt(utils.StatefulOptimizer):
  promote: bool = False
 
  def __init__(self, params, defaults, foreach: bool, *fns):
+ defaults = {k: v for k, v in defaults.items() if v is not use_default}
  super().__init__(params, defaults, foreach)
- self.fns = tuple(fns)
+ self.fns = fns
+
+ @property
+ def fns(self):
+ return self._fns
+
+ @fns.setter
+ def fns(self, value):
+ self._fns = value
+ self._set_indices(retain=True)
+
+ def _set_indices(self, retain=True):
+ self._fns = set_indices(self.fns, retain)
 
  def _step(self, group):
  if "base_lr" not in group:
@@ -868,7 +993,6 @@ class ChainOpt(utils.StatefulOptimizer):
  group["step"] = None
 
 
- use_default = object()
  str_or_fn = Union[str, callable, None, Literal[use_default]]
 
 
@@ -1020,3 +1144,27 @@ class ScheduleFree(BaseOpt):
  p32 = utils.promote(p.data)
  p32.lerp_(end=z, weight=1 - beta1)
  utils.copy_stochastic_(p.data, p32)
+
+
+ class MSAM(BaseOpt):
+ def eval(self):
+ for group in self.param_groups:
+ group["train_mode"] = train_mode = not group.get("train_mode")
+ if not train_mode:
+ for p in group["params"]:
+ state = self.state_(p)
+ if "z" in state:
+ p_copy = p.data.clone()
+ utils.copy_stochastic_(p.data, state["z"])
+ utils.copy_stochastic_(state["z"], p_copy)
+
+ def train(self):
+ for group in self.param_groups:
+ group["train_mode"] = train_mode = not group.get("train_mode")
+ if train_mode:
+ for p in group["params"]:
+ state = self.state_(p)
+ if "z" in state:
+ p_copy = p.data.clone()
+ utils.copy_stochastic_(p.data, state["z"])
+ utils.copy_stochastic_(state["z"], p_copy)
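
Note: the MSAM.eval()/train() methods added at the end of this diff toggle between the live parameters and the stored "z" buffers by swapping their contents (heavyball uses utils.copy_stochastic_ for the copies). A minimal, self-contained sketch of that swap pattern in plain PyTorch (hypothetical helper name, ordinary copy_ instead of heavyball's stochastic copy) might look like this:

    import torch

    def swap_weights_(param: torch.Tensor, z: torch.Tensor) -> None:
        # Exchange the contents of `param` and `z` in place, mirroring the
        # p.data <-> state["z"] exchange in MSAM.eval()/train().
        tmp = param.detach().clone()   # keep the current weights
        param.data.copy_(z)            # load the stored weights
        z.copy_(tmp)                   # stash the previous weights back into the buffer

    p = torch.nn.Parameter(torch.randn(4))
    z = torch.zeros(4)
    swap_weights_(p, z)  # p now holds z's old values, z holds p's old values
    swap_weights_(p, z)  # swapping twice restores the original assignment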