heavyball 1.7.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -1,14 +1,16 @@
+ import copy
  import functools
  import math
  import random
- from typing import List, Literal, Optional, Union
+ from collections.abc import Iterable as _Iterable
+ from typing import Iterable, List, Literal, Optional, Union

  import torch
  from torch import Tensor

  from . import utils

- balance_probability: float = 0.01
+ use_default = utils.use_default


  def _key_in_state(state, key):
@@ -36,20 +38,68 @@ def _guard_in_state(state, key, template_fn):


  class FunctionTransform:
-     def __init__(self, fn):
+     def __init__(self, fn, names: list[str] | None = None):
+         if names is None:
+             names = []
          self.fn = fn
          self.fn_name = self.get_fn().__name__
+         self.transform_idx = None
+         self.is_initialized = False
+         self.names = names

-     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+         raise NotImplementedError
+
+     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
          raise NotImplementedError

+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         states = [state(p) for p in param]
+         skip_update = False
+         for st, a in zip(states, zip(update, grad, param, *args)):
+             if self.transform_idx not in st.get("is_initialized", set()):
+                 try:
+                     self._init(st, group, *a, **kwargs)
+                 except SkipUpdate:
+                     skip_update = True
+                 except:
+                     raise
+                 finally:
+                     if "is_initialized" not in st:
+                         st["is_initialized"] = set()
+                     st["is_initialized"].add(self.transform_idx)
+         if skip_update:
+             raise SkipUpdate from None
+         vars = [[st.get(self.val_name(name), None) for st in states] for name in self.names]
+         return self._call(state, group, update, grad, param, vars, *args, **kwargs)
+
      def get_fn(self):
          if utils.hasattr_none(self.fn, "get_fn"):
              return self.fn.get_fn()
          return self.fn

      def val_name(self, name):
-         return f"{self.fn_name}_{name}"
+         assert self.transform_idx is not None
+         return f"{self.fn_name}_{name}_{self.transform_idx}"
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
+
+
+ class Branch:
+     def __init__(self, branches: List[List[callable]], merge_fn: callable):
+         self.branches = branches
+         self.merge_fn = merge_fn
+
+     def __call__(self, state, group, update, grad, param):
+         outputs = []
+         for branch in self.branches:
+             branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+             branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+             if skip_update:
+                 raise ValueError("Branches should not skip updates")
+             outputs.append(branch_update)
+         return self.merge_fn(outputs)


  def _zero_guard(state, key, ref, dtype):
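Note (editorial, not part of the package): the 2.0 FunctionTransform splits per-parameter state allocation (`_init`) from the per-step transform (`_call`), and namespaces state keys with `transform_idx` via `val_name`, so the same transform can appear more than once in a chain. A minimal sketch of a guard written against that split, assuming the class name `EmaGuard` and the buffer name `"ema"` (both hypothetical, modelled on ZeroGuard above):

    import torch
    from heavyball.chainable import FunctionTransform

    class EmaGuard(FunctionTransform):  # hypothetical subclass for illustration
        def __init__(self, fn):
            super().__init__(fn, ["ema"])  # registers the state names this guard owns

        def _init(self, state, group, update, grad, param, *args, **kwargs):
            # runs once per parameter; val_name() adds transform_idx to the key
            state[self.val_name("ema")] = torch.zeros_like(param)

        def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
            (ema,) = vars  # one list of per-parameter buffers for each registered name
            return self.fn(state, group, update, grad, param, *args, ema, **kwargs)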
@@ -63,49 +113,102 @@ def _storage_dtype(group):

  class ZeroGuard(FunctionTransform):
      def __init__(self, fn, names):
-         super().__init__(fn)
-         self.names = names
+         super().__init__(fn, names)

-     def __call__(self, state, group, update, grad, param, *args, **kwargs):
-         vars = [
-             [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] #
-             for name in self.names
-         ]
+     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+         for name in self.names:
+             _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
          return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


+ class PrecondGradAccumGuard(FunctionTransform):
+     def __init__(self, fn):
+         super().__init__(fn, ["precond_grad_accum"])
+         self.steps_taken = 0
+         self.pass_through = None
+
+     def _accum(self, state, new):
+         self.steps_taken += 1
+         utils.stochastic_add_(state, new)
+
+     def _reset(self, state):
+         if self.steps_taken != 0:
+             self.steps_taken = 0
+             utils.zero_(state)
+
+     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+         if self.pass_through is None:
+             self.pass_through = not group.get("precond_grad_accum", False)
+         if self.pass_through is False:
+             for name in self.names:
+                 _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+
+     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+         base_grad = update if group.get("momentum_into_precond_update", True) else grad
+         if self.pass_through:
+             return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
+
+         (vars,) = vars
+         if group["is_preconditioning"]:
+             if self.steps_taken:
+                 self._accum(vars, base_grad)
+                 utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+             else:
+                 vars = base_grad
+         else:
+             self._accum(vars, base_grad)
+             vars = base_grad
+         try:
+             out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
+         finally:
+             if group["is_preconditioning"]:
+                 self._reset(vars)
+
+         return out
+
+
  class CopyGuard(FunctionTransform):
      def __init__(self, fn, index, names):
-         super().__init__(fn)
+         super().__init__(fn, names)
          self.index = index
-         self.names = names

-     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
          val = [update, grad, param, *args][self.index]
-         vars = [
-             [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] #
-             for name in self.names
-         ]
+         for name in self.names:
+             state[self.val_name(name)] = torch.clone(val)
+
+     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
          return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


- class GeneralGuard(FunctionTransform): # We can't guard against reuse in the general case
+ class GeneralGuard(FunctionTransform):
      def __init__(self, fn, names, init_fn, skip_first: bool = True):
-         super().__init__(fn)
-         self.names = names
+         super().__init__(fn, names)
          self.init_fn = init_fn
          self.skip_first = skip_first
-
-     def __call__(self, state, group, update, grad, param, *args, **kwargs):
-         vars = []
-         skip_update = False
-         for p, g, u in zip(param, grad, update):
-             st = state(p)
-             skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
-             vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
-         if skip_update and self.skip_first:
-             raise SkipUpdate
-         return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
+         self.named_to_anonymous = None
+         self.anonymous_to_named = None
+
+     def _map(self, state_fn, param, mapping):
+         for p in param:
+             state = state_fn(p)
+             for name, mapped in mapping.items():
+                 if mapped in state:
+                     raise ValueError(f"Name {name} already mapped to {mapped}")
+                 if name in state:
+                     state[mapped] = state.pop(name)
+
+     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
+         self.init_fn(state, group, update, grad, param, **kwargs)
+         for name in self.names:
+             state[self.val_name(name)] = state.pop(name, None)
+         if self.skip_first:
+             raise SkipUpdate from None
+
+     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
+         return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)


  class NoState(FunctionTransform):
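Note (editorial): when `group["precond_grad_accum"]` is enabled, the new PrecondGradAccumGuard feeds the preconditioner the mean of the gradients seen since the last preconditioning step rather than only the latest one. A self-contained illustration of that averaging arithmetic, with made-up gradient values (the variable names below are only a stand-in for the guard's internal calls):

    import torch

    grads = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([6.0])]
    accum = torch.zeros(1)
    for steps_taken, g in enumerate(grads, start=1):
        accum += g                                   # mirrors utils.stochastic_add_(state, new)
    preconditioner_input = accum / steps_taken       # mirrors utils.stochastic_multiply_(vars, 1 / steps_taken)
    print(preconditioner_input)                      # tensor([3.]) -- the mean of the accumulated gradients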
@@ -124,10 +227,27 @@ class NoStateNoForeach(FunctionTransform):
                  skip_update = True
                  pass
          if skip_update:
-             raise SkipUpdate
+             raise SkipUpdate from None
          return updates


+ class SqueezeGrad(FunctionTransform):
+     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+         original_shapes = [u.shape for u in update]
+         update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
+         grad = [x.view_as(u) for x, u in zip(grad, update)]
+         param = [x.view_as(u) for x, u in zip(param, update)]
+         args = list(args)
+         for i, a in enumerate(args):
+             if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                 args[i] = [x.view_as(u) for x, u in zip(a, update)]
+         for k, a in kwargs.items():
+             if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                 kwargs[k] = [x.view_as(u) for x, u in zip(a, update)]
+         out = self.fn(state, group, update, grad, param, *args, **kwargs)
+         return [o.view(s) for o, s in zip(out, original_shapes)]


  def zero_guard(*names):
      return functools.partial(ZeroGuard, names=names)

@@ -158,12 +278,23 @@ def exp_avg(group, update, grad, param, exp_avg):
      return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


+ @copy_guard(2, "init")
+ @no_state
+ def weight_decay_to_init(group, update, grad, param, init):
+     utils.stochastic_lerp_(param, init, group["weight_decay_to_ema"] * group["lr"])
+     return update
+
+
+ def identity(state, group, update, grad, param):
+     return update
+
+
  @zero_guard("exp_avg")
  @no_state
  def weight_decay_to_ema(group, update, grad, param, exp_avg):
      utils.weight_decay_to_ema_(
+         param,
          exp_avg,
-         update,
          utils.beta_debias(group["ema_beta"], group["step"]),
          group["weight_decay_to_ema"] * group["lr"],
      )
@@ -174,8 +305,8 @@ def weight_decay_to_ema(group, update, grad, param, exp_avg):
  @no_state
  def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
      utils.l1_weight_decay_to_ema_(
+         param,
          exp_avg,
-         update,
          utils.beta_debias(group["ema_beta"], group["step"]),
          group["weight_decay_to_ema"] * group["lr"],
      )
@@ -221,7 +352,27 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
          group["weight_decay"],
          group["caution"],
      )
-     raise SkipUpdate
+     raise SkipUpdate from None
+
+
+ @zero_guard("exp_avg", "exp_avg_sq")
+ @no_state
+ def update_by_adamc(group, update, grad, param, exp_avg, exp_avg_sq):
+     utils.fused_adam_(
+         param,
+         exp_avg,
+         exp_avg_sq,
+         update,
+         grad,
+         utils.get_beta1(group),
+         utils.get_beta2(group),
+         group["step"],
+         group["lr"],
+         group["eps"],
+         group["lr"] * group["weight_decay"] / group["max_lr"],
+         group["caution"],
+         group["caution"] if False else group["caution"],
+     )
+     raise SkipUpdate from None


  @zero_guard("exp_avg", "exp_avg_sq")
@@ -246,7 +397,7 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
          group["weight_decay"],
          group["caution"],
      )
-     raise SkipUpdate
+     raise SkipUpdate from None


  @no_state
@@ -271,7 +422,26 @@ def update_by_schedule_free(group, update, grad, param, z):
          group["step"],
          group["weight_decay"],
      )
-     raise SkipUpdate
+     raise SkipUpdate from None
+
+
+ @copy_guard(2, "z")
+ @zero_guard("exp_avg")
+ @no_state
+ def update_by_msam(group, update, grad, param, z, exp_avg):
+     utils.msam_(
+         group["lr"],
+         utils.beta_debias(utils.get_beta1(group), group["step"]),
+         param,
+         z,
+         update,
+         grad,
+         exp_avg,
+         group["caution"],
+         group["weight_decay"],
+         group["sam_step_size"],
+     )
+     raise SkipUpdate from None


  @zero_guard("exp_avg", "exp_avg_sq")
@@ -279,7 +449,7 @@ def update_by_schedule_free(group, update, grad, param, z):
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
      if group["step"] == 1:
          utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-         raise SkipUpdate
+         raise SkipUpdate from None

      if group["step"] == 2:
          update = utils.promote(update)
@@ -291,7 +461,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
              utils.beta_debias(utils.get_beta2(group), group["step"]),
              group["eps"],
          )
-         raise SkipUpdate
+         raise SkipUpdate from None

      utils.fused_adopt_(
          param,
@@ -307,7 +477,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
          group["weight_decay"],
          group["caution"],
      )
-     raise SkipUpdate
+     raise SkipUpdate from None


  @zero_guard("exp_avg", "exp_avg_sq")
@@ -315,7 +485,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
      if group["step"] == 1:
          utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-         raise SkipUpdate
+         raise SkipUpdate from None

      if group["step"] == 2:
          update = utils.promote(update)
@@ -327,7 +497,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
              utils.beta_debias(utils.get_beta2(group), group["step"]),
              group["eps"],
          )
-         raise SkipUpdate
+         raise SkipUpdate from None

      return utils.adopt(
          update,
@@ -344,10 +514,11 @@ def _init_soap(state, group, update, grad, param, inner: str = ""):


  def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
-     Q, state["exprs"] = utils.init_Q_exprs(
+     Q = utils.init_Q_exprs(
          grad,
          group["precond_init_scale"],
          group["precond_init_scale_scale"],
+         group["precond_init_scale_power"],
          group["max_size_triangular"],
          group["min_ndim_triangular"],
          group["memory_save_mode"],
@@ -356,32 +527,28 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
          dtype=getattr(torch, group["q_dtype"]),
      )
      state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-
+     state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+     state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
+     if group["adaptive"]:
+         state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
      if not cached:
          return

      state["Q_cache"] = [torch.empty_like(q) for q in Q]

-     expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)]
-     expr = ",".join(expr)
-     grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape))
-     out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
-     expr = f"{expr},{grad_expr}->{out_expr}"
-
-     state["cache_expr"] = expr
-

  def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
      state["U"], state["V"], state["d"] = utils.init_lra(
          grad,
+         group["param_count"],
          group["precond_init_scale"],
          group["precond_init_scale_scale"],
+         group["precond_init_scale_power"],
          group["rank"],
          getattr(param, "hessian_vector", None),
          getattr(param, "vector", None),
          dtype=getattr(torch, group["q_dtype"]),
      )
-     group["preconditioning_step"] = 0


  def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
@@ -402,12 +569,12 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str

  @no_state_no_foreach
  def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
-     if update.dim() == 1:
+     if update.dim() < 2:
          return update
      original_shape = update.shape
      # doing it this way, as tmp and update are not guaranteed to share memory address or layout
      tmp = update.flatten(1, -1)
-     utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
+     utils.inplace_orthogonal_(tmp, out=tmp, scale_mode=scale_mode)
      return tmp.reshape(original_shape)


@@ -424,7 +591,7 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok


  def _store_std(state, group, update, grad, param):
-     state["init_std"] = torch.std(grad, dim=0)
+     state["init_std"] = torch.std(param)


  @general_guard("init_std", init_fn=_store_std, skip_first=False)
@@ -483,9 +650,7 @@ _optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_}
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"):
-     update = utils.promote(update) # Promote to highest precision if needed
-
-     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
+     grad_projected = [utils.project(utils.promote(u), q, False) for u, q in zip(update, Q)]
      fn = _optim_fns[inner]
      precond = fn(
          exp_avg,
@@ -500,7 +665,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:

      for u, q, gg, ea in zip(update, Q, GG, exp_avg):
          utils.update_preconditioner(
-             u,
+             utils.promote(u),
              q,
              gg,
              ea,
@@ -512,35 +677,38 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
      return precond


- def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+ def _update_psgd_precond(
+     cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, step, prob: Optional[callable] = None
+ ) -> Optional[Tensor]:
      if prob is None:
          prob = utils.precond_update_prob_schedule()

      if not group["is_preconditioning"]:
-         return Q_mat
+         return

      if utils.hasattr_none(param, "vector"):
          vector, hessian_vector = param.vector, param.hessian_vector
          del param.vector
          del param.hessian_vector
+     elif group["inverse_free"]:
+         vector, hessian_vector = None, grad
      else:
-         vector, hessian_vector = utils.dampen_grad(grad)
+         vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])

-     utils.psgd_update_precond(
-         Q_mat,
-         exprs,
+     precond = utils.psgd_update_precond(
          hessian_vector,
          group["precond_lr"],
          Q,
          group["store_triu_as_line"],
+         velocity,
+         utils.get_beta2(group),
+         group["ortho_method"],
          vector,
+         running_lower_bound,
+         group["lower_bound_beta"],
+         group["precond_update_power_iterations"],
      )
-
-     if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
-         if group["store_triu_as_line"]:
-             utils.psgd_balance_Q([q_ for _, q_ in Q])
-         else:
-             utils.psgd_balance_Q(Q)
+     del vector, hessian_vector

      if isinstance(prob, float):
          float_prob = prob
@@ -548,54 +716,45 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
          float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1))
      group["is_cached"] = should_use_cache = cached and float_prob < 0.5

-     if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
-         return _update_psgd_cache(cached, Q_cache, Q_mat)
-     return Q_mat
-
+     if precond is not None:
+         return precond
+     if not should_use_cache or not cached:
+         return None # caching adds extra ops and is not worth the overhead when we precondition at every step

- def _update_psgd_cache(cached, Q_cache, q):
-     if not cached:
-         return q
-
-     for c_, q_ in zip(Q_cache, q):
+     for c_, q_ in zip(Q_cache, utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q):
          if q_.ndim == 2:
              torch.matmul(q_.T, q_, out=c_)
          else:
              torch.mul(q_, q_, out=c_)
-     return Q_cache
+     return None


- def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
+ def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
+     kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
      if group.get("is_cached", False):
-         out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad)
+         out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
      else:
-         out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad)
+         out = utils.psgd_precond_grad(
+             preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+         )
      group["caution"] = False # we already cautioned here - shouldn't do it again
      return out


- def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
+ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
+     kwargs = {
+         "ea": update,
+         "caution": group["caution"],
+         "grad": grad,
+         "param": param,
+         "lr": group["lr"],
+         "decay": group["weight_decay"],
+     }
      if group.get("is_cached", False):
-         utils.fused_precond_grad_cached_(
-             cache_expr,
-             update,
-             param,
-             group["lr"],
-             grad,
-             group["weight_decay"],
-             group["caution"],
-             *Q_cache,
-         )
+         utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
      else:
          utils.fused_psgd_precond_grad(
-             exprs[-1],
-             update,
-             param,
-             group["lr"],
-             grad,
-             group["weight_decay"],
-             group["caution"],
-             *Q_mat,
+             preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
          )


@@ -603,7 +762,7 @@ def _update_lra(
      group, U: List[Tensor], V: List[Tensor], d: List[Tensor], params: List[Tensor], grads: List[Tensor], delayed: bool
  ):
      if not group["is_preconditioning"]:
-         return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
+         return utils.multi_flatten((U, 1), (V, 1), (d, 0))

      if utils.hasattr_none(params[0], "hessian_vector"):
          vector = utils.flatten([p.vector for p in params])
@@ -613,127 +772,121 @@ def _update_lra(
              del p.hessian_vector
      else:
          vector, hessian_vector = utils.dampen_multiple(grads)
-     return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
+     precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
+     return utils.update_lra_precond_(
+         U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed, bool(precond_step % 2)
+     )


+ @SqueezeGrad
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def scale_by_psgd_lra(group, update, grad, param, U, V, d):
-     u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
+ def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+     u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
      return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))


+ @SqueezeGrad
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def update_by_psgd_lra(group, update, grad, param, U, V, d):
-     u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, False)
-     utils.apply_lra_update(param, update, u, v, d)
-     raise SkipUpdate
+ def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+     u, v, d = _update_lra(group, U, V, d, param, update_to_precond, False)
+     utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+     raise SkipUpdate from None


+ @SqueezeGrad
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def scale_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-     u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
+ def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+     u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
      return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))


+ @SqueezeGrad
+ @PrecondGradAccumGuard
  @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
  @no_state
- def update_by_delayed_psgd_lra(group, update, grad, param, U, V, d):
-     u, v, d = _update_lra(group, U, V, d, param, update if group["momentum_into_precond_update"] else grad, True)
-     utils.apply_lra_update(param, update, u, v, d)
-     raise SkipUpdate
+ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
+     u, v, d = _update_lra(group, U, V, d, param, update_to_precond, True)
+     utils.apply_lra_update(param, update, u, v, d, group["lr"], group["weight_decay"], group["caution"], grad)
+     raise SkipUpdate from None


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @SqueezeGrad
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def scale_by_psgd(
      group,
      update,
      grad,
      param,
+     update_to_precond,
      Q,
-     exprs,
      Q_cache,
-     cache_expr: str,
+     velocity: Optional[List[Tensor]],
+     running_lower_bound: List[Tensor],
+     step: Tensor,
      cached: bool = False,
      prob: Optional[callable] = None,
  ):
-     update = update.to(memory_format=torch.contiguous_format)
-     Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
-     Q_mat = _update_psgd_precond(
-         cached,
-         Q_cache,
-         group,
-         param,
-         update if group["momentum_into_precond_update"] else grad,
-         Q_mat,
-         Q,
-         exprs,
-         prob,
-     )
-     return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
+     _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+     return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @SqueezeGrad
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def scale_by_delayed_psgd(
      group,
      update,
      grad,
      param,
+     update_to_precond,
      Q,
-     exprs,
      Q_cache,
-     cache_expr: str,
+     velocity: Optional[List[Tensor]],
+     running_lower_bound: List[Tensor],
+     step: Tensor,
      cached: bool = False,
      prob: Optional[callable] = None,
  ):
-     Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
-     precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
-     _ = _update_psgd_precond(
-         cached,
-         Q_cache,
-         group,
-         param,
-         update if group["momentum_into_precond_update"] else grad,
-         Q_mat,
-         Q,
-         exprs,
-         prob,
+     if group.get("inverse_free", False):
+         precond = None
+     else:
+         precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+     new = _update_psgd_precond(
+         cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
      )
-     return precond
+     return new if precond is None else precond


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @SqueezeGrad
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def update_by_psgd(
      group,
      update,
      grad,
      param,
+     update_to_precond,
      Q,
-     exprs,
      Q_cache,
-     cache_expr: str,
+     velocity: Optional[List[Tensor]],
+     running_lower_bound: List[Tensor],
+     step: Tensor,
      cached: bool = False,
      prob: Optional[callable] = None,
  ):
-     Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
-     Q_mat = _update_psgd_precond(
-         cached,
-         Q_cache,
-         group,
-         param,
-         update if group["momentum_into_precond_update"] else grad,
-         Q_mat,
-         Q,
-         exprs,
-         prob,
-     )
-     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-     raise SkipUpdate
+     _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+     raise SkipUpdate from None


  @no_state
@@ -741,34 +894,33 @@ def sign(group, update, grad, param, graft: bool = True):
      return utils.sign_(update, graft)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd_kron, skip_first=False)
+ @no_state
+ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
+     assert clip_fn is not None
+     return clip_fn(update)
+
+
+ @SqueezeGrad
+ @PrecondGradAccumGuard
+ @general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
  @no_state_no_foreach
  def update_by_delayed_psgd(
      group,
      update,
      grad,
      param,
+     update_to_precond,
      Q,
-     exprs,
      Q_cache,
-     cache_expr: str,
+     velocity: Optional[List[Tensor]],
+     running_lower_bound: List[Tensor],
+     step: Tensor,
      cached: bool = False,
      prob: Optional[callable] = None,
  ):
-     Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
-     _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
-     _ = _update_psgd_precond(
-         cached,
-         Q_cache,
-         group,
-         param,
-         update if group["momentum_into_precond_update"] else grad,
-         Q_mat,
-         Q,
-         exprs,
-         prob,
-     )
-     raise SkipUpdate
+     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+     _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+     raise SkipUpdate from None


  def palm_beta2(state, group, update, grad, param):
@@ -805,26 +957,63 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
          utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)


- def create_branch(branches: List[List[callable]], merge_fn: callable):
-     def _branch(state, group, update, grad, param):
-         outputs = []
-         for branch in branches:
-             branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
-             branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
-             if skip_update:
-                 raise ValueError("Branches should not skip updates")
-             outputs.append(branch_update)
-         return merge_fn(outputs)
+ def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+     if retain and offset:
+         raise ValueError("offset cannot be retained")

-     return _branch
+     def _walk(obj):
+         stack = [obj]
+         while stack:
+             cur = stack.pop()
+             if isinstance(cur, FunctionTransform):
+                 yield cur
+                 stack.append(cur.fn)
+             elif isinstance(cur, functools.partial):
+                 stack.append(cur.func)
+             elif isinstance(cur, Branch):
+                 for branch in cur.branches:
+                     stack.extend(branch)
+             elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
+                 stack.extend(cur)
+
+     if retain:
+         offset = max((ft.transform_idx for ft in _walk(fns) if ft.transform_idx is not None), default=-1) + 1
+
+     new_fns = [copy.deepcopy(fn) for fn in fns]
+     for ft in _walk(new_fns):
+         if not retain or ft.transform_idx is None:
+             ft.transform_idx, offset = offset, offset + 1
+
+     return new_fns


  class ChainOpt(utils.StatefulOptimizer):
      promote: bool = False
+     global_defaults = {
+         "caution": False,
+         "lr": 1,
+         "warmup_steps": 0,
+         "weight_decay": 0,
+         "eps": 1e-8,
+     }

      def __init__(self, params, defaults, foreach: bool, *fns):
-         super().__init__(params, defaults, foreach)
-         self.fns = tuple(fns)
+         base = self.global_defaults.copy()
+         base.update({k: v for k, v in defaults.items() if v is not use_default})
+         super().__init__(params, base, foreach)
+         self.fns = fns
+
+     @property
+     def fns(self):
+         return self._fns
+
+     @fns.setter
+     def fns(self, value):
+         self._fns = value
+         self._set_indices(retain=True)
+
+     def _set_indices(self, retain=True):
+         self._fns = set_indices(self.fns, retain)

      def _step(self, group):
          if "base_lr" not in group:
@@ -868,7 +1057,6 @@ class ChainOpt(utils.StatefulOptimizer):
          group["step"] = None


- use_default = object()
  str_or_fn = Union[str, callable, None, Literal[use_default]]


@@ -931,7 +1119,7 @@ class BaseOpt(ChainOpt):

      update_clipping: str_or_fn = None
          The function to use for clipping the outgoing updates before applying them, after all other transformations.
-         This will turn off
+         This will turn off fused updates.
          This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.

      """
@@ -945,11 +1133,11 @@ class BaseOpt(ChainOpt):
          self,
          params,
          defaults,
-         foreach: bool,
-         gradient_clipping: str_or_fn,
-         update_clipping: str_or_fn,
+         foreach: bool = True,
+         gradient_clipping: str_or_fn = None,
+         update_clipping: str_or_fn = None,
          palm: bool = use_default,
-         *fns,
+         fns: Iterable[callable] = (),
          compile_step: bool = use_default,
          promote: bool = use_default,
      ):
@@ -957,6 +1145,7 @@ class BaseOpt(ChainOpt):
              raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")

          args, kwargs = None, None
+         fns = tuple(fns)
          fn = fns[-1]
          if isinstance(fn, functools.partial):
              fn, args, kwargs = fn.func, fn.args, fn.keywords
@@ -1020,3 +1209,27 @@ class ScheduleFree(BaseOpt):
                          p32 = utils.promote(p.data)
                          p32.lerp_(end=z, weight=1 - beta1)
                          utils.copy_stochastic_(p.data, p32)
+
+
+ class MSAM(BaseOpt):
+     def eval(self):
+         for group in self.param_groups:
+             group["train_mode"] = train_mode = not group.get("train_mode")
+             if not train_mode:
+                 for p in group["params"]:
+                     state = self.state_(p)
+                     if "z" in state:
+                         p_copy = p.data.clone()
+                         utils.copy_stochastic_(p.data, state["z"])
+                         utils.copy_stochastic_(state["z"], p_copy)
+
+     def train(self):
+         for group in self.param_groups:
+             group["train_mode"] = train_mode = not group.get("train_mode")
+             if train_mode:
+                 for p in group["params"]:
+                     state = self.state_(p)
+                     if "z" in state:
+                         p_copy = p.data.clone()
+                         utils.copy_stochastic_(p.data, state["z"])
+                         utils.copy_stochastic_(state["z"], p_copy)
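Note (editorial): in 2.0 the BaseOpt constructor takes the transform chain as a keyword iterable (`fns=...`) with defaults for `foreach` and the clipping arguments, instead of trailing positional `*fns`, and ChainOpt's new `global_defaults` backfills `lr`, `eps`, `weight_decay`, `caution`, and `warmup_steps`. A minimal sketch of a call against that signature, assuming the module is imported as `heavyball.chainable`; the subclass name and hyperparameter values are illustrative, not taken from the package:

    from heavyball import chainable as C

    class PlainAdam(C.BaseOpt):  # hypothetical subclass for illustration
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0.0):
            defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
            # fns is now a keyword iterable; update_by_adam is the fused Adam step from this file
            super().__init__(params, defaults, fns=(C.update_by_adam,))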