heavyball-1.7.2-py3-none-any.whl → heavyball-2.0.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
heavyball/utils.py CHANGED
@@ -1,28 +1,29 @@
1
+ import collections
1
2
  import contextlib
2
3
  import functools
3
4
  import gc
4
5
  import inspect
5
6
  import math
7
+ import pickle
6
8
  import random
7
9
  import re
8
10
  import string
9
11
  import warnings
10
- from typing import Callable, List, Optional, Tuple, Union
12
+ from typing import Callable, List, Literal, Optional, Tuple, Union
11
13
 
12
14
  import numpy as np
13
15
  import torch
14
16
  from torch import Tensor
15
- from torch._dynamo import config
16
17
  from torch._dynamo.exc import TorchDynamoException
17
18
  from torch.backends import cudnn, opt_einsum
19
+ from torch.nn import functional as F
18
20
  from torch.utils._pytree import tree_map
19
21
 
20
- config.cache_size_limit = 2**16
21
-
22
22
  compile_mode = "max-autotune-no-cudagraphs"
23
23
  dynamic = False
24
24
  compile_mode_recommended_to_none = None
25
- zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
25
+ zeroth_power_mode = "newtonschulz"
26
+ precise_zeroth_power_mode = "qr" # or svd
26
27
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
28
  _cudnn_double_backward_pattern = re.compile(
28
29
  r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -68,6 +69,16 @@ def decorator_knowngood(func: Callable, fullgraph: bool = True):
68
69
  einsum_base = string.ascii_lowercase
69
70
 
70
71
 
72
+ @decorator_knowngood
73
+ def compiled_einsum(expr, *args):
74
+ """
75
+ This wrapper is needed to avoid the slowdown of uncompiled einsum:
76
+ uncompiled einsum becomes twice as slow once three size-1 dimensions are added.
77
+ For details, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
78
+ """
79
+ return torch.einsum(expr, *args)
80
+
81
+
71
82
  @decorator_knowngood
72
83
  def _compilable_schedule_free_(
73
84
  p: List[Tensor],
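For reference, the `compiled_einsum` helper added above only wraps `torch.einsum` in the package's compile decorator so that repeated calls, including ones with extra size-1 dimensions, hit a compiled kernel. A minimal sketch of the same pattern, assuming plain `torch.compile` stands in for `decorator_knowngood`:

import torch

@torch.compile(fullgraph=True, dynamic=False)  # stand-in for decorator_knowngood (assumption)
def compiled_einsum_sketch(expr, *args):
    return torch.einsum(expr, *args)

a = torch.randn(8, 1, 16)
b = torch.randn(16, 1, 4)
out = compiled_einsum_sketch("xiy,yjz->xijz", a, b)  # size-1 dims stay cheap once compiled
print(out.shape)  # torch.Size([8, 1, 1, 4])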
@@ -122,6 +133,47 @@ def schedule_free_(
122
133
  return weight_sum
123
134
 
124
135
 
136
+ @decorator_knowngood
137
+ def _compilable_msam(
138
+ lr: Tensor,
139
+ beta1: Tensor,
140
+ param: List[Tensor],
141
+ z: List[Tensor],
142
+ update: List[Tensor],
143
+ grad: List[Tensor],
144
+ exp_avg: List[Tensor],
145
+ caution: bool,
146
+ decay: Tensor,
147
+ sam_step_size: Tensor,
148
+ ):
149
+ exp_avg32 = _lerp(exp_avg, update, beta1)
150
+ for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
151
+ u_ = u_.view_as(z_)
152
+ z32_ = promote(z_)
153
+ if caution:
154
+ u_ = _compilable_cautioning(promote(g_), u_)
155
+ z32_ = z32_ * (1 - decay * lr) + u_ * -lr
156
+ copy_stochastic_(z_, z32_)
157
+ copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
158
+
159
+
160
+ def msam_(
161
+ lr: float,
162
+ beta1: float,
163
+ param: List[Tensor],
164
+ z: List[Tensor],
165
+ update: List[Tensor],
166
+ grad: List[Tensor],
167
+ exp_avg: List[Tensor],
168
+ caution: bool,
169
+ weight_decay: float,
170
+ sam_step_size: float,
171
+ ):
172
+ param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
173
+ lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
174
+ _compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
175
+
176
+
125
177
  def append_or_extend(base, new):
126
178
  if isinstance(new, list):
127
179
  base.extend(new)
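For reference, the fused `_compilable_msam` added above is a momentum-SAM style update. An unfused fp32 sketch of the per-tensor math (names are illustrative; the exact `_lerp` weighting is an assumption and the `caution` masking is omitted):

import torch

def msam_step_sketch(p, z, update, exp_avg, lr, beta1, weight_decay, sam_step_size):
    exp_avg.lerp_(update, 1 - beta1)                    # momentum buffer (assumed lerp convention)
    u = exp_avg
    z.mul_(1 - weight_decay * lr).add_(u, alpha=-lr)    # decoupled decay + step on the slow weights z
    p.copy_(z - sam_step_size * u / u.norm().clamp(min=1e-8))  # expose z perturbed along -u (the SAM offset)

p, z = torch.zeros(10), torch.zeros(10)
exp_avg, update = torch.zeros(10), torch.randn(10)
msam_step_sketch(p, z, update, exp_avg, lr=1e-3, beta1=0.9, weight_decay=0.01, sam_step_size=0.05)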
@@ -161,7 +213,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
161
213
  new_shape = [grad.shape[0], *new_shape[::-1]]
162
214
  new_grad = grad.reshape(new_shape)
163
215
  if not split:
164
- return new_grad
216
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
165
217
 
166
218
  grads = [new_grad]
167
219
  for i, sh in reversed(list(enumerate(new_shape[:]))):
@@ -172,7 +224,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
172
224
  continue
173
225
  grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
174
226
  if len(grads) == 1:
175
- return new_grad
227
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
176
228
  new_grads = []
177
229
  for g in grads:
178
230
  append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
@@ -189,14 +241,14 @@ def eps_sqrt(item, eps):
189
241
 
190
242
  @decorator_knowngood
191
243
  def _compilable_exp_avg_sq_(
192
- state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
244
+ state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
193
245
  ):
194
246
  g32 = promote(grad)
195
247
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
196
248
 
197
249
  denom = [eps_sqrt(d, eps) for d in s32]
198
250
 
199
- if out[0] is None:
251
+ if out is None or out[0] is None:
200
252
  return denom
201
253
 
202
254
  copy_stochastic_list_(out, denom)
@@ -265,8 +317,8 @@ def adaptive_gradient_clipping_(
265
317
  def is_compiling():
266
318
  try:
267
319
  return torch.compiler.is_compiling()
268
- except TorchDynamoException:
269
- return True
320
+ except (TorchDynamoException, AttributeError):
321
+ return False
270
322
 
271
323
 
272
324
  def set_(dst: Tensor, src: Tensor):
@@ -279,16 +331,29 @@ def clean():
279
331
 
280
332
 
281
333
  def _ignore_warning(msg):
282
- warnings.filterwarnings("ignore", f".*{msg}.*")
334
+ warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
283
335
 
284
336
 
285
- def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
337
+ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
338
+ import opt_einsum as _opt_einsum
339
+
286
340
  cudnn.benchmark = True
287
341
  cudnn.deterministic = False
288
342
  cudnn.benchmark_limit = benchmark_limit
289
343
  torch.use_deterministic_algorithms(False)
290
344
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
291
- opt_einsum.set_flags(True, einsum_strategy)
345
+ opt_einsum.set_flags(True)
346
+ if einsum_strategy == "heavyball":
347
+ opt_einsum.strategy = "auto-hq"
348
+ choices = _opt_einsum.paths._AUTO_HQ_CHOICES
349
+ for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
350
+ if isinstance(fn, int):
351
+ fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
352
+ for i in range(max(choices.keys()), max_val):
353
+ if i not in choices:
354
+ choices[i] = fn
355
+ else:
356
+ opt_einsum.strategy = einsum_strategy
292
357
 
293
358
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
294
359
  _ignore_warning(
@@ -297,32 +362,39 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
297
362
  _ignore_warning(
298
363
  "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
364
  )
365
+ _ignore_warning(
366
+ "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
367
+ )
300
368
 
301
369
 
302
- @decorator
370
+ @decorator_knowngood
303
371
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
304
- assert len(G.shape) == 2
305
- a, b, c = (3.4445, -4.7750, 2.0315)
306
- X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
307
- X /= X.norm() + eps # ensure top singular value <= 1
308
- if G.size(0) > G.size(1):
309
- X = X.T
310
- for _ in range(steps):
311
- A = X @ X.T
312
- B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
372
+ assert (
373
+ G.ndim >= 2
374
+ ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
375
+ assert steps == 5
376
+ X = G if G.dtype == torch.float64 else stochastic_round_(G)
377
+ if G.size(-2) > G.size(-1):
378
+ X = X.mT
379
+
380
+ stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps) # ensure top singular value <= 1
381
+ # Perform the NS iterations
382
+ for a, b, c in [
383
+ (4.0848, -6.8946, 2.9270),
384
+ (3.9505, -6.3029, 2.6377),
385
+ (3.7418, -5.5913, 2.3037),
386
+ (2.8769, -3.1427, 1.2046),
387
+ (2.8366, -3.0525, 1.2012),
388
+ ]:
389
+ A = X @ X.mT
390
+ B = (
391
+ b * A + c * A @ A
392
+ ) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
313
393
  X = a * X + B @ X
314
- if G.size(0) > G.size(1):
315
- X = X.T
316
- return X.to(G.dtype)
317
-
318
394
 
319
- def ortho(x):
320
- if zeroth_power_mode == "qr":
321
- return torch.linalg.qr(x).Q
322
- if zeroth_power_mode == "svd":
323
- u, _s, v = torch.linalg.svd(x)
324
- return u @ v.T
325
- raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
395
+ if G.size(-2) > G.size(-1):
396
+ X = X.mT
397
+ return X.to(G.dtype)
326
398
 
327
399
 
328
400
  @decorator_knowngood
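The rewritten `zeropower_via_newtonschulz5` above switches to per-step quintic coefficients, batched `.mT` transposes, and bf16 stochastic rounding. A plain fp32, single-matrix sketch of the same iteration (dividing by the norm, as the "top singular value <= 1" comment intends):

import torch

NS_COEFFS = [(4.0848, -6.8946, 2.9270), (3.9505, -6.3029, 2.6377),
             (3.7418, -5.5913, 2.3037), (2.8769, -3.1427, 1.2046),
             (2.8366, -3.0525, 1.2012)]

def newton_schulz5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    X = G / (G.norm() + eps)                  # top singular value <= 1
    tall = G.size(-2) > G.size(-1)
    if tall:
        X = X.mT                              # iterate on the wide orientation
    for a, b, c in NS_COEFFS:
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X   # quintic Newton-Schulz step
    return X.mT if tall else X

G = torch.randn(64, 32)
X = newton_schulz5(G)
print(torch.linalg.svdvals(X))  # all pushed towards 1, i.e. X is approximately semi-orthogonal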
@@ -377,7 +449,7 @@ def _compilable_grafting(magnitude, direction):
377
449
 
378
450
 
379
451
  @decorator_knowngood
380
- def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
452
+ def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
381
453
  if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
382
454
  y = zeropower_via_newtonschulz5(x, 5)
383
455
  elif mode == "qr":
@@ -395,9 +467,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
395
467
  y = _compilable_grafting(x, y)
396
468
  else:
397
469
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
470
+ if out is None:
471
+ return y
472
+
398
473
  set_(out, y)
399
474
 
400
475
 
476
+ def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
477
+ return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
478
+
479
+
401
480
  @decorator_knowngood
402
481
  def _compilable_scatter_set(target, source, index):
403
482
  target[:] = source.contiguous()[index].reshape_as(target)
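Usage sketch for the optional-`out` signature of `inplace_orthogonal_` introduced above (assumes `heavyball.utils` is importable; with `out=None` the result is returned instead of written in place):

import torch
from heavyball import utils

x = torch.randn(64, 64)
y = utils.inplace_orthogonal_(x)                  # mode=None -> module-level zeroth_power_mode
buf = torch.empty_like(x)
utils.inplace_orthogonal_(x, mode="qr", out=buf, scale_mode="none")  # explicit mode, writes into buf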
@@ -413,6 +492,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
413
492
  :param Q: List of current eigenbases (updated in-place to Q_new).
414
493
  :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
415
494
  """
495
+ if exp_avg.dim() == 0: # preconditioning doesn't make sense here
496
+ Q.clear()
497
+ return
498
+
416
499
  if isinstance(Q, list) and not Q:
417
500
  return
418
501
 
@@ -430,10 +513,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
430
513
  q_old = promote(q.data)
431
514
 
432
515
  tmp = m @ q_old
433
- est_eig = torch.einsum("ij,ij->j", q_old, tmp)
516
+ est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
434
517
  sort_idx = torch.argsort(est_eig, descending=True)
435
518
 
436
- tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
519
+ tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
437
520
  new_qs.append(tmp)
438
521
 
439
522
  if exp_avg is None:
@@ -453,7 +536,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
453
536
  out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
454
537
 
455
538
  subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
456
- exp_avg_new = torch.einsum(
539
+ exp_avg_new = compiled_einsum(
457
540
  subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
541
  )
459
542
  copy_stochastic_(exp_avg, exp_avg_new)
@@ -487,10 +570,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
487
570
  except torch.OutOfMemoryError:
488
571
  if m.device.type == "cpu":
489
572
  raise
490
- else:
573
+ if torch.cuda.is_available():
574
+ torch.cuda.synchronize(m.device)
575
+ clean()
576
+ m = m.cpu()
577
+ except RuntimeError as e:
578
+ if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
579
+ torch.cuda.synchronize(m.device)
580
+ clean()
491
581
  m = m.cpu()
492
- except RuntimeError: # failed to compute eigenvalues
493
- if m.dtype != torch.double:
582
+ elif m.dtype != torch.double:
494
583
  m = m.double()
495
584
  elif eps < max_eps:
496
585
  eps = eps ** (2 / 3)
@@ -568,6 +657,19 @@ def scalar_guard(*args):
568
657
  return out
569
658
 
570
659
 
660
+ def broadcastable_list_guard(*xs):
661
+ xs = list_guard(*xs)
662
+ for x in xs:
663
+ if isinstance(x[0], Tensor):
664
+ ref = x[0]
665
+ break
666
+ else:
667
+ raise ValueError("No tensor-valued input given")
668
+ xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
669
+ max_len = max(len(x) for x in xs)
670
+ return [x if len(x) > 1 else x * max_len for x in xs]
671
+
672
+
571
673
  @decorator_knowngood
572
674
  def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
573
675
  for x_, y_ in zip(x, y):
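What `broadcastable_list_guard` adds over plain `list_guard`: scalars are turned into tensors matching the first tensor input, and length-1 lists are repeated to the longest length so the foreach routines can zip them. A rough standalone rendering of that behaviour (the `list_guard`/`scalar_guard` stand-ins below are simplified assumptions):

import torch

def broadcastable_list_guard_sketch(*xs):
    xs = [list(x) if isinstance(x, (list, tuple)) else [x] for x in xs]        # list_guard stand-in
    ref = next(x[0] for x in xs if isinstance(x[0], torch.Tensor))             # first tensor-valued input
    xs = [x if isinstance(x[0], torch.Tensor)
          else [torch.as_tensor(v, dtype=ref.dtype, device=ref.device) for v in x]
          for x in xs]                                                         # scalar_guard stand-in
    max_len = max(len(x) for x in xs)
    return [x if len(x) > 1 else x * max_len for x in xs]

params = [torch.zeros(3), torch.zeros(5)]
xs, ys = broadcastable_list_guard_sketch(params, 1.0)
print(len(ys), ys[0])  # 2 tensor(1.) -- the scalar now pairs with every parameter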
@@ -576,8 +678,8 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
576
678
  copy_stochastic_(x_, x32 + y32 * alpha)
577
679
 
578
680
 
579
- def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
580
- x, y = list_guard(x, y)
681
+ def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
682
+ x, y = broadcastable_list_guard(x, y)
581
683
  alpha = scalar_guard(alpha, x[0])
582
684
  _compilable_stochastic_add_(x, y, alpha)
583
685
 
@@ -590,8 +692,10 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
590
692
  copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
591
693
 
592
694
 
593
- def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
594
- x, y = list_guard(x, y)
695
+ def stochastic_add_divide_(
696
+ x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
697
+ ):
698
+ x, y = broadcastable_list_guard(x, y)
595
699
  alpha, divisor = scalar_guard(alpha, divisor, x[0])
596
700
  _compilable_stochastic_add_divide_(x, y, alpha, divisor)
597
701
 
@@ -604,8 +708,8 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
604
708
  copy_stochastic_(x_, x32 * y32)
605
709
 
606
710
 
607
- def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
608
- x, y = list_guard(x, y)
711
+ def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
712
+ x, y = broadcastable_list_guard(x, y)
609
713
  _compilable_stochastic_multiply_(x, y)
610
714
 
611
715
 
@@ -624,7 +728,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
624
728
  b = einsum_base[idx]
625
729
  g0 = einsum_base[: grad.dim()]
626
730
  g1 = g0.replace(b, b.upper())
627
- outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
731
+ outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
628
732
  stochastic_lerp_(m, outer_product, 1 - beta)
629
733
 
630
734
 
@@ -706,7 +810,7 @@ def project(grad, Q, back: bool):
706
810
  preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
707
811
  if preconditioners:
708
812
  out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
709
- out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
813
+ out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
710
814
  grad = out.to(grad.dtype)
711
815
  return grad
712
816
 
@@ -714,24 +818,28 @@ def project(grad, Q, back: bool):
714
818
  @contextlib.contextmanager
715
819
  def patch_backward():
716
820
  @contextlib.contextmanager
717
- def _inner(module):
821
+ def patch_module(module):
718
822
  original = module.backward
719
-
720
- signature = inspect.signature(original)
721
-
722
- def patched_backward(*args, **kwargs):
723
- new_kwargs = signature.bind(*args)
724
- new_kwargs.apply_defaults()
725
- new_kwargs = new_kwargs.arguments
726
- new_kwargs.update(kwargs)
727
- new_kwargs["create_graph"] = True
728
- return original(**new_kwargs)
729
-
730
- module.backward = patched_backward
731
- yield
732
- module.backward = original
733
-
734
- with _inner(torch.Tensor), _inner(torch.autograd):
823
+ try:
824
+ signature = inspect.signature(original)
825
+
826
+ @functools.wraps(original)
827
+ def patched_backward(*args, **kwargs):
828
+ new_kwargs = signature.bind(*args)
829
+ new_kwargs.apply_defaults()
830
+ new_kwargs = new_kwargs.arguments
831
+ new_kwargs.update(kwargs)
832
+ new_kwargs["create_graph"] = True
833
+ return original(**new_kwargs)
834
+
835
+ module.backward = patched_backward
836
+ yield
837
+ finally:
838
+ module.backward = original
839
+
840
+ with contextlib.ExitStack() as stack:
841
+ stack.enter_context(patch_module(torch.Tensor))
842
+ stack.enter_context(patch_module(torch.autograd))
735
843
  yield
736
844
 
737
845
 
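The rewritten `patch_backward` above combines `functools.wraps`, a `try`/`finally` restore, and `contextlib.ExitStack`. A minimal standalone version of the same patching pattern (only `torch.Tensor.backward`, without the signature binding the package does):

import contextlib
import functools
import torch

@contextlib.contextmanager
def force_create_graph():
    original = torch.Tensor.backward

    @functools.wraps(original)
    def patched(self, *args, **kwargs):
        kwargs["create_graph"] = True       # the only behavioural change
        return original(self, *args, **kwargs)

    torch.Tensor.backward = patched
    try:
        yield
    finally:
        torch.Tensor.backward = original    # restored even if the block raises

x = torch.randn(3, requires_grad=True)
with force_create_graph():
    (x ** 2).sum().backward()
print(x.grad.requires_grad)  # True: the gradient graph was kept for double backward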
@@ -743,6 +851,13 @@ class ExactHVPFailed(ValueError):
743
851
  pass
744
852
 
745
853
 
854
+ use_default = object()
855
+
856
+
857
+ def _tensor_key(x: Tensor):
858
+ return x.data_ptr(), x.numel(), x.dtype, x.device
859
+
860
+
746
861
  class StatefulOptimizer(torch.optim.Optimizer):
747
862
  """
748
863
  finite_differences saves memory, but needs more compute. (Alternative is true HVP)
@@ -755,7 +870,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
755
870
  compile_step: bool = False
756
871
  hessian_approx: bool = False
757
872
  precond_schedule: Union[Callable, float, None] = None
758
- stochastic_schedule: bool = False
873
+ stochastic_schedule: bool | Literal[use_default] = use_default
759
874
  finite_differences: bool = False
760
875
  fallback_to_finite_differences: bool = True
761
876
  _fallback_enabled: bool = False
@@ -765,18 +880,62 @@ class StatefulOptimizer(torch.optim.Optimizer):
765
880
  super().__init__(params, {**defaults, "foreach": foreach})
766
881
  self.use_ema = use_ema
767
882
  self.mapping = {}
768
- self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
769
- self._precond_rng = random.Random(0x12312)
883
+ self.mapping_inverse = {}
884
+
885
+ if self.stochastic_schedule is use_default:
886
+ stochastic_schedule = None
887
+ for group in self.param_groups:
888
+ new = group.get("stochastic_schedule", stochastic_schedule)
889
+ if stochastic_schedule is not None and new != stochastic_schedule:
890
+ raise ValueError("All parameter groups must have the same stochastic_schedule.")
891
+ stochastic_schedule = new
892
+ self.stochastic_schedule = stochastic_schedule
893
+
894
+ self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
895
+ self.precond_rng = random.Random(0x12312)
770
896
  self._is_preconditioning = None
771
897
 
772
898
  if self.hessian_approx and self.compile_step:
773
899
  raise ValueError("Hessian approximation can't be used with compile_step.")
774
900
 
901
+ self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
902
+ self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
903
+
904
+ def _store_stats(self, state_dict: dict[str, any]):
905
+ state_dict["heavyball"] = {
906
+ "inner_group": self.inner_group,
907
+ "precond_rng": pickle.dumps(self.precond_rng),
908
+ "use_ema": self.use_ema,
909
+ "ema_decay": self.ema_decay,
910
+ "compile_step": self.compile_step,
911
+ "hessian_approx": self.hessian_approx,
912
+ "precond_schedule": pickle.dumps(self.precond_schedule),
913
+ "stochastic_schedule": self.stochastic_schedule,
914
+ "fallback_to_finite_differences": self.fallback_to_finite_differences,
915
+ "_fallback_enabled": self._fallback_enabled,
916
+ "hvp_interval": self.hvp_interval,
917
+ }
918
+
919
+ def _load_stats(self, state_dict):
920
+ sd = state_dict.pop("heavyball", {})
921
+ for k, v in sd.items():
922
+ if k in ("precond_rng", "precond_schedule"):
923
+ v = pickle.loads(v)
924
+ setattr(self, k, v)
925
+
775
926
  def get_groups(self, group):
776
927
  return [group]
777
928
 
778
- def state_(self, arg: Tensor):
779
- return self.state[arg]
929
+ @functools.lru_cache(maxsize=None)
930
+ def state_(self, arg: Tensor, fail: bool = True):
931
+ if not fail and arg not in self.mapping:
932
+ return {}
933
+ if _tensor_key(arg) not in self.mapping_inverse:
934
+ self._init_mapping()
935
+ state_param, index = self.mapping_inverse[_tensor_key(arg)]
936
+ if state_param not in self.state:
937
+ self.state[state_param] = collections.defaultdict(dict)
938
+ return self.state[state_param][index]
780
939
 
781
940
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
782
941
  for p, g in zip(p_list, g_list):
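Effect of the new `register_state_dict_post_hook` / `register_load_state_dict_pre_hook` pair above: checkpoints now carry the preconditioner RNG and schedule state under a "heavyball" key and restore it on load. A hedged usage sketch (the optimizer class name is an assumption; any `StatefulOptimizer` subclass exported by the package should behave the same way):

import torch
import heavyball

model = torch.nn.Linear(4, 4)
opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3)    # assumed class name

sd = opt.state_dict()
print("heavyball" in sd)            # True: inner_group, pickled precond_rng, schedules, ...

opt2 = heavyball.ForeachSOAP(model.parameters(), lr=1e-3)
opt2.load_state_dict(sd)            # _load_stats unpickles and restores RNG/schedule state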
@@ -786,6 +945,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
786
945
  old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
787
946
  mars_correction(g_list, old_gs, mars_gamma, beta)
788
947
 
948
+ def _init_mapping(self, group: dict | None = None):
949
+ if group is None:
950
+ for group in self.param_groups:
951
+ self._init_mapping(group)
952
+ return
953
+
954
+ for p in group["params"]:
955
+ if p not in self.mapping:
956
+ self.mapping[p] = p_views = merge_group(group, p)
957
+ for i, pv in enumerate(p_views):
958
+ self.mapping_inverse[_tensor_key(pv)] = (p, i)
959
+
789
960
  def split_p_and_g_in_group(
790
961
  self,
791
962
  group: dict,
@@ -805,10 +976,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
805
976
  yield p, grad
806
977
  continue
807
978
 
808
- if p in self.mapping:
809
- p_views = self.mapping[p]
810
- else:
811
- self.mapping[p] = p_views = merge_group(group, p)
979
+ self.mapping[p] = p_views = merge_group(group, p)
980
+ for i, pv in enumerate(p_views):
981
+ self.mapping_inverse[_tensor_key(pv)] = (p, i)
812
982
 
813
983
  vector = getattr(p, "vector", None)
814
984
  hessian_vector = getattr(p, "hessian_vector", None)
@@ -957,8 +1127,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
957
1127
  raise ValueError("Hessian approximation requires a closure.")
958
1128
  return None
959
1129
 
960
- step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
- if not hessian_approx or step % self.hvp_interval == 0:
1130
+ step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
1131
+ if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
962
1132
  with torch.enable_grad():
963
1133
  loss = closure()
964
1134
  return loss
@@ -997,12 +1167,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
997
1167
  if self.precond_schedule is None:
998
1168
  self._is_preconditioning = False
999
1169
  else:
1000
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
1170
+ self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
1001
1171
  loss = self._handle_closure(closure)
1002
1172
 
1003
1173
  # we assume that parameters are constant and that there are no excessive recompiles
1004
1174
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
1005
1175
  for group in self.param_groups:
1176
+ if "param_count" not in group:
1177
+ group["param_count"] = sum(p.numel() for p in group["params"])
1006
1178
  group["is_preconditioning"] = self._is_preconditioning
1007
1179
  self._step(group)
1008
1180
  if self.use_ema:
@@ -1105,7 +1277,7 @@ def fused_adam_(
1105
1277
  caution: bool,
1106
1278
  ):
1107
1279
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
1108
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
1280
+ beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
1109
1281
  _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
1110
1282
 
1111
1283
 
@@ -1184,7 +1356,7 @@ def fused_laprop_(
1184
1356
  eps: float = 1e-8,
1185
1357
  ):
1186
1358
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
1187
- beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
1359
+ beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
1188
1360
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
1189
1361
 
1190
1362
 
@@ -1203,7 +1375,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
1203
1375
 
1204
1376
  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
1205
1377
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
1206
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
1378
+ beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
1207
1379
  _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
1208
1380
 
1209
1381
 
@@ -1233,11 +1405,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
1233
1405
 
1234
1406
 
1235
1407
  @decorator_knowngood
1236
- def stochastic_round_(ref: Tensor, source: Tensor):
1237
- if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
1238
- return source
1239
- if ref.dtype != torch.bfloat16:
1240
- return source.to(ref.dtype)
1408
+ def stochastic_round_(ref: Tensor, source: Tensor | None = None):
1409
+ if source is not None:
1410
+ if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
1411
+ return source
1412
+ if ref.dtype != torch.bfloat16:
1413
+ return source.to(ref.dtype)
1414
+ else:
1415
+ source = ref
1416
+ source = source.float()
1241
1417
  result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
1242
1418
  result.add_(source.view(dtype=torch.int32))
1243
1419
  result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
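The integer trick `stochastic_round_` relies on, as a standalone fp32 -> bf16 helper (illustrative only; the in-tree version above also handles dtype pass-through and the new `source=None` path):

import torch

def stochastic_round_to_bf16(x32: torch.Tensor) -> torch.Tensor:
    noise = torch.randint_like(x32, dtype=torch.int32, low=0, high=1 << 16)
    bits = x32.view(torch.int32) + noise   # random carry into the kept mantissa bits
    bits = bits & -65536                   # -65536 == 0xFFFF0000: keep sign, exponent, top 7 mantissa bits
    return bits.view(torch.float32).bfloat16()

x = torch.full((100_000,), 1.00001)
print(stochastic_round_to_bf16(x).float().mean())  # ~1.00001 in expectation, unlike plain truncation to 1.0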
@@ -1306,83 +1482,115 @@ def stable_exp(x: Tensor):
1306
1482
  return torch.where(x > 0, 1 / (-x).exp(), x.exp())
1307
1483
 
1308
1484
 
1485
+ def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
1486
+ # ln(mean(x ** pow) ** (1 / pow / 2))
1487
+ normalization = math.log(x.numel())
1488
+ x = x.double()
1489
+ x = x.abs()
1490
+ x = x.clamp(min=eps)
1491
+ x = x.log()
1492
+ x = x * pow
1493
+ x = x.flatten()
1494
+ x = x.logsumexp(dim=0) # log(sum(exp(log(x) * P))) - more stable than sum(x ** P)
1495
+ x = x - normalization # sum -> mean (divide by x.numel() in log space)
1496
+ return x / pow / 2
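A quick numerical check of what `_lse_mean` computes, ln(mean(|x|^pow)^(1 / (2 * pow))), against the direct formula; `mean_root` and `divided_root` below are then just `stable_exp` of (differences of) this quantity. The log-sum-exp route only matters for extreme magnitudes, so fp64 agrees here:

import math
import torch

def lse_mean_ref(x: torch.Tensor, pow: float, eps: float = 1e-12) -> torch.Tensor:
    x = x.double().abs().clamp(min=eps).flatten()
    return ((x.log() * pow).logsumexp(0) - math.log(x.numel())) / pow / 2

x = torch.randn(1000)
direct = x.double().abs().pow(4).mean().pow(1 / 8).log()   # log(mean(|x|^4) ** (1/(2*4)))
print(torch.allclose(lse_mean_ref(x, 4), direct))          # True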
1497
+
1498
+
1309
1499
  @decorator_knowngood
1310
1500
  def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
1311
1501
  # 1 / (mean(x ** pow) ** (1 / pow / 2))
1312
- log_x = x.double().abs().clamp(min=eps).log()
1313
- log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
1314
- return stable_exp(-log_mean_x_pow / pow / 2)
1502
+ return stable_exp(-_lse_mean(x, pow, eps))
1315
1503
 
1316
1504
 
1317
1505
  @decorator_knowngood
1318
1506
  def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
1319
1507
  # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
1320
- log_x = x.double().abs().clamp(min=eps).log()
1321
- log_y = y.double().abs().clamp(min=eps).log()
1322
-
1323
- x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
1324
- x_normed = x_normed / pow0 / 2
1508
+ return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
1325
1509
 
1326
- y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
1327
- y_normed = y_normed / pow1 / 2
1328
1510
 
1329
- return stable_exp(x_normed - y_normed)
1511
+ class PrecondInitError(ValueError):
1512
+ pass
1330
1513
 
1331
1514
 
1332
- def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
1515
+ def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
1333
1516
  automatic_scale = True
1334
1517
  manual_hint = " Set it manually using `precond_init_scale=0.1`"
1518
+ scale_scale = 1 if scale_scale is None else scale_scale
1335
1519
 
1336
1520
  if scale is not None:
1337
1521
  automatic_scale = False
1338
1522
  warn_once(
1339
1523
  "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
1340
1524
  )
1341
- if scale_scale is not None and scale_scale != 1:
1525
+ if scale_scale != 1:
1526
+ warn_once(
1527
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
1528
+ )
1529
+ if scale_power is not None:
1342
1530
  warn_once(
1343
- "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
1531
+ "precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
1344
1532
  )
1345
1533
  elif hessian_vector is None:
1346
1534
  scale = mean_root(grad, 4) * scale_scale
1347
1535
  else:
1348
1536
  scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
1349
1537
 
1538
+ if automatic_scale:
1539
+ scale_power = 0.5 if scale_power is None else scale_power
1540
+ scale = scale**scale_power
1541
+
1350
1542
  if isinstance(scale, torch.Tensor):
1351
1543
  scale = scale.item() # slow, but necessary
1352
1544
 
1353
1545
  if np.isfinite(scale):
1354
- if scale > scale_max or scale < 1 / scale_max: # fallthrough to later checks
1546
+ if scale > scale_max: # fallthrough to later checks
1355
1547
  warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
1356
1548
  else:
1357
1549
  return scale
1358
1550
 
1359
1551
  if not automatic_scale:
1360
- raise ValueError("The manually set precond_init_scale is not finite")
1552
+ raise PrecondInitError("The manually set precond_init_scale is not finite")
1361
1553
 
1362
1554
  for x in (grad, hessian_vector, vector):
1363
1555
  if x is None:
1364
1556
  continue
1365
- if torch.allclose(x, torch.zeros_like(x)).item():
1366
- raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
1557
+ if torch.allclose(x, torch.zeros_like(x)):
1558
+ raise PrecondInitError(
1559
+ f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
1560
+ )
1367
1561
  if not torch.isfinite(x).all().item():
1368
- raise ValueError("Grad or HVP is not finite")
1562
+ raise PrecondInitError("Grad or HVP is not finite")
1369
1563
 
1370
1564
  if np.isfinite(scale):
1371
1565
  return scale
1372
1566
 
1373
- raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1567
+ raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
1374
1568
 
1375
1569
 
1376
- def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
1377
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1378
- U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1379
- V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1570
+ def init_lra(
1571
+ grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
1572
+ ):
1573
+ # "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
1574
+ # https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
1575
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1576
+ uv_scale = (param_count * (rank + eps)) ** -0.5
1577
+ U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1578
+ V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1380
1579
  d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
1381
1580
  return U, V, d
1382
1581
 
1383
1582
 
1384
1583
  def init_Q_exprs(
1385
- grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1584
+ grad,
1585
+ scale,
1586
+ scale_scale,
1587
+ scale_power,
1588
+ max_size,
1589
+ min_ndim_triangular,
1590
+ memory_save_mode,
1591
+ hessian_vector,
1592
+ vector,
1593
+ dtype=None,
1386
1594
  ):
1387
1595
  """
1388
1596
  For a scalar or tensor `grad`, we initialize its preconditioner Q and
@@ -1391,21 +1599,13 @@ def init_Q_exprs(
1391
1599
  precond init scale computation from
1392
1600
  https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
1393
1601
  """
1394
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1395
- letters = string.ascii_lowercase + string.ascii_uppercase
1602
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1396
1603
  dtype = dtype if dtype is not None else grad.dtype
1397
1604
  shape = grad.shape
1398
1605
 
1399
1606
  if len(shape) == 0: # scalar
1400
1607
  Q = [scale * torch.ones_like(grad, dtype=dtype)]
1401
- exprA = ",->"
1402
- exprGs = [",->"]
1403
- exprP = ",,->"
1404
- return [Q, (exprA, tuple(exprGs), exprP)]
1405
-
1406
- # Tensor
1407
- if len(shape) > 13:
1408
- raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
1608
+ return Q
1409
1609
 
1410
1610
  scale = scale ** (1 / len(shape))
1411
1611
 
@@ -1418,6 +1618,9 @@ def init_Q_exprs(
1418
1618
  sorted_shape = sorted(shape)
1419
1619
  if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
1420
1620
  dim_diag[_max_idx(shape)] = True
1621
+ elif memory_save_mode == "one_triu":
1622
+ shape_ranks = np.argsort(np.argsort(shape)) # ranks
1623
+ dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
1421
1624
  elif memory_save_mode == "all_diag":
1422
1625
  dim_diag = [True for _ in shape]
1423
1626
  else:
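What the new "one_triu" memory_save_mode selects, as seen above: only the smallest axis keeps a full triangular preconditioner, every other axis falls back to a diagonal one. For example:

import numpy as np

shape = (256, 64, 3)
shape_ranks = np.argsort(np.argsort(shape))   # rank of each axis by size
dim_diag = (shape_ranks != 0).tolist()
print(dim_diag)  # [True, True, False] -> only the size-3 axis gets a triangular Q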
@@ -1427,66 +1630,90 @@ def init_Q_exprs(
1427
1630
  )
1428
1631
 
1429
1632
  Q = []
1430
- piece1A, piece2A, piece3A = ([], "", "")
1431
- exprGs = []
1432
- piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
1433
1633
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
1434
1634
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
1435
1635
  # use diagonal matrix as preconditioner for this dim
1436
1636
  Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
1437
-
1438
- piece1A.append(letters[i])
1439
- piece2A = piece2A + letters[i]
1440
- piece3A = piece3A + letters[i]
1441
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1442
- subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1443
- exprGs.append(subscripts)
1444
- piece1P.append(letters[i + 13])
1445
- piece2P.append(letters[i + 13])
1446
- piece3P = piece3P + letters[i + 13]
1447
- piece4P = piece4P + letters[i + 13]
1448
1637
  else:
1449
1638
  # use triangular matrix as preconditioner for this dim
1450
1639
  Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
1451
- piece1A.append(letters[i] + letters[i + 13])
1452
- piece2A = piece2A + letters[i + 13]
1453
- piece3A = piece3A + letters[i]
1454
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1455
- piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1456
- subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
1457
- exprGs.append(subscripts)
1458
- a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1459
- piece1P.append(a + b)
1460
- piece2P.append(a + c)
1461
- piece3P = piece3P + c
1462
- piece4P = piece4P + b
1463
-
1464
- exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
1465
- exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
1466
- return [Q, (exprA, tuple(exprGs), exprP)]
1640
+ return Q
1467
1641
 
1468
1642
 
1469
- @decorator
1470
- def psgd_balance_Q(Q_in):
1471
- norms = torch.stack([q.norm(float("inf")) for q in Q_in])
1472
- geometric_mean = norms.log().mean().exp()
1473
- norms = geometric_mean / norms
1474
- torch._foreach_mul_(Q_in, list(norms))
1643
+ @decorator_knowngood
1644
+ def psgd_balance_Q(Q):
1645
+ norms = [promote(q.norm(float("inf"))).log() for q in Q]
1646
+ geometric_mean = sum([n for n in norms]) / len(Q)
1647
+ for q, n in zip(Q, norms):
1648
+ q *= (geometric_mean - n).exp()
1475
1649
 
1476
1650
 
1477
- @decorator
1478
- def psgd_balance_lra(U: Tensor, V: Tensor):
1479
- u_norm = promote(torch.linalg.vector_norm(U))
1480
- v_norm = promote(torch.linalg.vector_norm(V))
1481
- scale = (u_norm / v_norm) ** 0.5
1482
- U.div_(scale)
1483
- V.mul_(scale)
1651
+ @decorator_knowngood
1652
+ def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
1653
+ u_norm = sum(u.square().sum().double() for u in U)
1654
+ v_norm = sum(v.square().sum().double() for v in V)
1655
+ scale = (u_norm / v_norm) ** 0.25 # sqrt of L2 norms; sqrt, as it's 2 factors
1656
+ scale = torch.where(torch.logical_and(torch.isfinite(scale), scale > 1e-6), scale, 1)
1657
+ stochastic_multiply_(U, [1 / scale] * len(U))
1658
+ stochastic_multiply_(V, [scale] * len(V))
1659
+ return multi_flatten((U, 1), (V, 1), (d, 0))
1484
1660
 
1485
1661
 
1486
1662
  @decorator
1487
1663
  def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1488
1664
  dtype = min_dtype([U, V, x])
1489
- return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1665
+ return x + compiled_einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1666
+
1667
+
1668
+ @decorator_knowngood
1669
+ def _compilable_d_step(
1670
+ d: Tensor,
1671
+ d_orig: List[Tensor],
1672
+ invQtv: Tensor,
1673
+ vector: Tensor,
1674
+ inverse_precond_vector: Tensor,
1675
+ hessian_vector: Tensor,
1676
+ precond_hessian_vector: Tensor,
1677
+ eps: Tensor,
1678
+ step: Tensor,
1679
+ delayed: bool,
1680
+ ):
1681
+ precond_hessian_vector = promote(precond_hessian_vector)
1682
+ hessian_vector = promote(hessian_vector)
1683
+ vector = promote(vector)
1684
+ inverse_precond_vector = promote(inverse_precond_vector)
1685
+ invQtv = promote(invQtv)
1686
+ inverse_precond_vector = invQtv - inverse_precond_vector
1687
+
1688
+ nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector
1689
+
1690
+ """
1691
+ 1) Sketching
1692
+ 1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
1693
+ 1.2) reduced-precision selection of largest element (halves memory traffic)
1694
+ 2) Computation
1695
+ 2.1) select relevant indices
1696
+ 2.2) redo 1.1 in double precision for scalar values
1697
+ 2.3) return high-precision normalized step-size
1698
+ overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
1699
+ improving precision
1700
+ """
1701
+ a0 = promote(d) * precond_hessian_vector
1702
+ a1 = vector
1703
+ b0 = inverse_precond_vector / promote(d)
1704
+ b1 = hessian_vector
1705
+
1706
+ divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
1707
+ idx = divisor.bfloat16().flatten().argmax()
1708
+ a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
1709
+ b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
1710
+ divisor = (a * b).sqrt().clamp(min=eps)
1711
+ step = -step / divisor
1712
+
1713
+ # fused update(s)
1714
+ apply_flat_add(d_orig, nablaD, step)
1715
+ if not delayed:
1716
+ copy_stochastic_(d, promote(d) - nablaD * step)
1490
1717
 
1491
1718
 
1492
1719
  def update_lra_precond_(
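For orientation: `low_rank_mm(U, V, x)` above is just `(I + U V^T) x`, so `lra_precond` applies `diag(d) (I + V U^T)(I + U V^T) diag(d)` to the gradient. A dense cross-check of the einsum:

import torch

n, r = 16, 4
U, V, x = torch.randn(n, r), torch.randn(n, r), torch.randn(n)

fast = x + torch.einsum("br,gr,g->b", U, V, x)   # low_rank_mm without the dtype juggling
dense = x + U @ (V.T @ x)                        # (I + U V^T) x, materialized
print(torch.allclose(fast, dense, atol=1e-5))    # True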
@@ -1498,13 +1725,14 @@ def update_lra_precond_(
1498
1725
  eps: float,
1499
1726
  step: float,
1500
1727
  delayed: bool,
1728
+ precond_u: bool,
1501
1729
  ):
1502
1730
  """
1503
1731
  Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1504
1732
  """
1505
1733
  U_orig, V_orig, d_orig = U, V, d
1506
1734
 
1507
- U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
1735
+ U, V, d = _lra_flatten_and_balance(U, V, d)
1508
1736
 
1509
1737
  dtype = min_dtype([U, V, vector, hessian_vector])
1510
1738
  U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
@@ -1512,10 +1740,10 @@ def update_lra_precond_(
1512
1740
  eps = scalar_guard(eps, vector)
1513
1741
 
1514
1742
  Qh = low_rank_mm(U, V, d * hessian_vector)
1515
- Ph = d * low_rank_mm(V, U, Qh)
1743
+ Ph = low_rank_mm(V, U, Qh)
1516
1744
  rank = U.size(1)
1517
1745
 
1518
- VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
1746
+ VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
1519
1747
  I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
1520
1748
  IpVtU = I + VtU
1521
1749
  invQtv = vector / d
@@ -1533,47 +1761,39 @@ def update_lra_precond_(
1533
1761
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1534
1762
 
1535
1763
  invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
1536
- invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1537
- invPv = invPv / d
1538
-
1539
- nablaD = Ph * hessian_vector - vector * invPv
1540
- divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
1541
- divisor = divisor.add(eps).sqrt().max()
1542
- d_step = step / divisor
1764
+ invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1543
1765
 
1544
- apply_flat_add(d_orig, d * nablaD, -d_step)
1766
+ eps, step = scalar_guard(eps, step, vector)
1767
+ _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
1545
1768
 
1546
1769
  a, b = Qh, invQtv
1547
1770
 
1548
- precond_u = random.random() < 0.5 # update either U or V, not both at the same time
1549
1771
  precond = V if precond_u else U
1550
- atV = torch.einsum("b,br->r", a, precond) # o == one
1551
- btV = torch.einsum("b,br->r", b, precond)
1552
- atVVt = torch.einsum("r,br->b", atV, precond)
1553
- btVVt = torch.einsum("r,br->b", btV, precond)
1554
- precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
1772
+ atV = compiled_einsum("b,br->r", a, precond) # o == one
1773
+ btV = compiled_einsum("b,br->r", b, precond)
1774
+ atVVt = compiled_einsum("r,br->b", atV, precond)
1775
+ btVVt = compiled_einsum("r,br->b", btV, precond)
1776
+ precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
1555
1777
  if precond_u:
1556
- a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
1557
- b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
1778
+ a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
1779
+ b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
1558
1780
  else:
1559
- a = a + torch.einsum("br,r->b", V, atV)
1560
- b = b + torch.einsum("br,r->b", V, btV)
1561
- a = torch.einsum("b,r->br", a, atV)
1562
- b = torch.einsum("b,r->br", b, btV)
1781
+ a = a + compiled_einsum("br,r->b", V, atV)
1782
+ b = b + compiled_einsum("br,r->b", V, btV)
1783
+ a = compiled_einsum("b,r->br", a, atV)
1784
+ b = compiled_einsum("b,r->br", b, btV)
1563
1785
  apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
1564
-
1565
1786
  if not delayed:
1566
- stochastic_add_([d], [d * nablaD], -d_step)
1567
1787
  stochastic_add_([U if precond_u else V], [b - a], precond_step)
1568
1788
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1569
1789
 
1570
1790
 
1571
- def lra_precond(U, V, d, g):
1791
+ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
1572
1792
  """
1573
1793
  As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
1574
1794
  """
1575
- g = low_rank_mm(U, V, d * g)
1576
- return d * low_rank_mm(V, U, g)
1795
+ new_g = low_rank_mm(U, V, d * g)
1796
+ return d * low_rank_mm(V, U, new_g)
1577
1797
 
1578
1798
 
1579
1799
  @decorator_knowngood
@@ -1584,16 +1804,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
1584
1804
 
1585
1805
 
1586
1806
  @decorator_knowngood
1587
- def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
1588
- update = lra_precond(U, V, d, update)
1807
+ def _compilable_lra_update_(
1808
+ params: List[Tensor],
1809
+ update: List[Tensor],
1810
+ U: Tensor,
1811
+ V: Tensor,
1812
+ d: Tensor,
1813
+ lr: Tensor,
1814
+ decay: Tensor,
1815
+ caution: bool,
1816
+ grads: List[Tensor],
1817
+ ):
1818
+ update = lra_precond(U, V, d, flatten(update))
1589
1819
  start = 0
1590
1820
  update = update.flatten()
1591
- for p in params:
1821
+ for p, g in zip(params, grads):
1592
1822
  size = p.numel()
1593
- copy_stochastic_(p, update[start : start + size].view_as(p))
1823
+ update_param_(p, update[start : start + size].view_as(p), lr, decay, caution, g)
1594
1824
  start += size
1595
1825
 
1596
1826
 
1827
+ def apply_lra_update(
1828
+ params: List[Tensor],
1829
+ update: Tensor,
1830
+ U: Tensor,
1831
+ V: Tensor,
1832
+ d: Tensor,
1833
+ lr: float,
1834
+ decay: float,
1835
+ caution: bool,
1836
+ grads: List[Tensor],
1837
+ ):
1838
+ params, grads = list_guard(params, grads)
1839
+ lr, decay = scalar_guard(lr, decay, params[0])
1840
+ _compilable_lra_update_(params, update, U, V, d, lr, decay, caution, grads)
1841
+
1842
+
1597
1843
  @decorator_knowngood
1598
1844
  def apply_flat_update(params: List[Tensor], update: Tensor):
1599
1845
  start = 0
@@ -1604,6 +1850,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
1604
1850
  start += size
1605
1851
 
1606
1852
 
1853
+ @decorator_knowngood
1854
+ def zero_(x: List[Tensor]):
1855
+ for i in x:
1856
+ i.zero_()
1857
+
1858
+
1607
1859
  @decorator_knowngood
1608
1860
  def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1609
1861
  start = 0
@@ -1629,7 +1881,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1629
1881
  @decorator_knowngood
1630
1882
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1631
1883
  last_dim = x[0].shape[-remaining:] if remaining else []
1632
- return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
1884
+ return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
1885
+
1886
+
1887
+ @decorator_knowngood
1888
+ def multi_flatten(*xs: Tuple[List[Tensor], int]):
1889
+ return [flatten(x, i) for x, i in xs]
1633
1890
 
1634
1891
 
1635
1892
  @decorator_knowngood
@@ -1645,149 +1902,564 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1645
1902
 
1646
1903
  def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1647
1904
  md = min_dtype(args)
1648
- return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1905
+ return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1649
1906
 
1650
1907
 
1651
1908
  @decorator_knowngood
1652
1909
  def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
1653
1910
  triangular_qs = []
1911
+ conjB = promote(conjB)
1654
1912
  for i, q in enumerate(Qs):
1655
1913
  q = promote(q)
1656
1914
  if q.dim() <= 1:
1657
- shape = [1] * conjB.ndim
1658
- shape[i] = -1
1659
- conjB /= q.view(shape)
1915
+ if conjB.ndim == 0:
1916
+ conjB = conjB / q
1917
+ else:
1918
+ shape = [1] * conjB.ndim
1919
+ shape[i] = -1
1920
+ conjB = conjB / q.view(shape)
1660
1921
  else:
1661
1922
  triangular_qs.append((i, q))
1662
- return triangular_qs
1923
+ return triangular_qs, conjB
1663
1924
 
1664
1925
 
1665
1926
  @decorator_knowngood
1666
- def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
1927
+ def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
1928
+ solved = solved.reshape(transposed_shape)
1929
+ solved = solved.transpose(-1, last_dim)
1667
1930
  solved = solved.reshape(original_shape)
1668
- solved.transpose(last_dim, -1)
1669
- return solved.reshape(new_shape).contiguous()
1931
+ solved = solved.transpose(-1, new_dim)
1932
+ return solved.contiguous(), solved.shape
1933
+
1934
+
1935
+ def ndim_tuple(Q: list[Tensor]) -> tuple:
1936
+ return tuple(q.ndim for q in Q)
1670
1937
 
1671
1938
 
1672
- def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1673
- order = G.dim()
1674
- if order > 1:
1675
- conjB = conjB.view_as(G).permute(*range(1, order), 0)
1676
- conjB = conjB.to(promote(G.dtype))
1939
+ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "vector") == randn during hvp/whitening
1940
+ if conjB is None:
1941
+ conjB = torch.randn_like(G)
1942
+ exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
1677
1943
  A = casted_einsum(exprA, *Q, G)
1678
1944
  solve = torch.compiler.disable(torch.linalg.solve_triangular)
1679
- original_shape = conjB.shape
1945
+ transposed_shape = original_shape = conjB.shape
1680
1946
  prev_i = -1
1681
- for i, tri_q in _psgd_calc_scalars_(Q, conjB):
1682
- conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
1947
+ qs, conjB = _psgd_calc_scalars_(Q, conjB)
1948
+ for i, tri_q in qs:
1949
+ conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
1683
1950
  prev_i = i
1684
1951
  conjB = solve(tri_q, conjB, upper=True, left=False)
1685
- conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
1952
+ conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
1686
1953
  return A, conjB
1687
1954
 
1688
1955
 
1689
1956
  @decorator_knowngood
1690
- def _max_select(to_index: Tensor, to_argmax: Tensor):
1691
- idx = to_argmax.argmax()
1692
- return to_index.index_select(1, idx).flatten().contiguous()
1957
+ def _random_projection(x: Tensor, scale: Optional[Tensor]):
1958
+ if scale is None:
1959
+ scale = x.norm(float("inf")).clamp(min=1e-8)
1960
+ k = 2 ** math.ceil(math.log2(math.log2(min(x.shape)))) # next-largest-power-of-2 of log2-of-size
1961
+ norm = x.square().sum(0)
1962
+ indices = torch.topk(norm, k, largest=True).indices
1963
+ return x.index_select(1, indices).contiguous() / scale, scale
1693
1964
 
1694
1965
 
1695
- def psgd_lb(A: Tensor, max_abs: Tensor):
1696
- A /= max_abs
1697
- x = _max_select(A, torch.einsum("ij,ij->j", A, A))
1698
- x = torch.einsum("i,ij->j", x, A)
1699
- x /= x.norm()
1700
- x = torch.einsum("j,kj->k", x, A)
1701
- x = x.norm()
1702
- x *= max_abs
1703
- return x
1966
+ def max_singular_value_exact(A, use_lobpcg: bool = False):
1967
+ try:
1968
+ if use_lobpcg:
1969
+ A = A @ A.T
1970
+ eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
1971
+ return eigval[0].sqrt()
1972
+ else:
1973
+ return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype) # == linalg.matrix_norm(A, ord=2)
1974
+ except (torch.linalg.LinAlgError, RuntimeError):
1975
+ return max_singular_value_power_iter(promote(A), iterations=2)
1704
1976
 
1705
1977
 
1706
1978
  @decorator_knowngood
1707
- def _subtract_from_line_(state: Tensor, term: Tensor):
1708
- stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
1979
+ def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
1980
+ """
1981
+ Rayleigh quotient of row with the largest norm + optional power iterations
1982
+ """
1983
+ x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
1984
+ x_norm = promote(x_norm)
1985
+
1986
+ def _inner():
1987
+ A = A_outer
1988
+ x = A.index_select(0, max_idx).flatten().contiguous()
1989
+ A = stochastic_round_(A / x_norm)
1990
+ x = x / x_norm
1991
+
1992
+ def _mv(x):
1993
+ return promote(A.T.mv(A.mv(stochastic_round_(x))))
1994
+
1995
+ for _ in range(iterations):
1996
+ # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
1997
+ x = F.normalize(_mv(x), dim=0)
1998
+ out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
1999
+ return out.squeeze().clone()
2000
+
2001
+ return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
1709
2002
 
1710
2003
 
1711
2004
  @decorator_knowngood
1712
- def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
1713
- out = term1.float().triu() * fac
1714
- out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
1715
- copy_stochastic_(term1, out)
2005
+ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
2006
+ """
2007
+ Adapted from @evanatyourservice
2008
+ """
2009
+ Y, max_abs = _random_projection(A, max_abs)
2010
+ Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
2011
+ Q = Q / max_abs
2012
+ Z = A.T @ Q
2013
+ W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
2014
+ sketch_norm = max_singular_value_exact(Z.T @ W)
2015
+ return sketch_norm * max_abs
2016
+
2017
+
2018
+ def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
2019
+ if A.ndim <= 2:
2020
+ return max_singular_value(A, max_svd, use_cholesky, power_iter)
2021
+
2022
+ base = einsum_base[: A.ndim]
2023
+ A16 = stochastic_round_(A)
2024
+ squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
2025
+ svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
2026
+ svds = torch.stack(svds)
2027
+ return svds.max().sqrt().to(A.dtype) # sqrt because we took the SVD of a squared matrix
1716
2028
 
1717
2029
 
1718
2030
  @decorator_knowngood
1719
- def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
1720
- copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
2031
+ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
2032
+ if A.ndim < 2:
2033
+ return A.abs().max()
2034
+ if A.ndim > 2:
2035
+ raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
2036
+ if min(A.shape) <= max_svd:
2037
+ return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
2038
+ if use_cholesky or power_iter < 0:
2039
+ return max_singular_value_cholesky(A)
2040
+ return max_singular_value_power_iter(A, None, iterations=power_iter)
1721
2041
 
1722
2042
 
1723
2043
  @decorator_knowngood
1724
- def _compilable_add_sub_(x: Tensor, y: Tensor):
1725
- x = promote(x)
1726
- y = promote(y)
1727
- return x - y, x + y
2044
+ def clamped_max_singular_value(
2045
+ A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
2046
+ ) -> Tensor:
2047
+ norm = A.norm() # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
2048
+ out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
2049
+ return out.clamp(min=min)
2050
+
2051
+
2052
+ @decorator_knowngood
2053
+ def min_singular_value(
2054
+ A: Tensor,
2055
+ power_iter: int = 5,
2056
+ safety: float = 1.05,
2057
+ max_svd: int = 32,
2058
+ ):
2059
+ if A.ndim < 2:
2060
+ return A.abs().min()
2061
+
2062
+ n = A.size(0)
2063
+ if n <= max_svd:
2064
+ try:
2065
+ eigs = torch.linalg.eigvalsh(promote(A))
2066
+ return eigs.min().to(A.dtype)
2067
+ except torch.linalg.LinAlgError:
2068
+ pass
2069
+
2070
+ lambda_max_hat = max_singular_value(A, power_iter=power_iter)
2071
+ lambda_upper = lambda_max_hat * safety
2072
+
2073
+ row_norms = A.norm(dim=1)
2074
+ norm, idx = row_norms.min(dim=0)
2075
+ v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))
2076
+
2077
+ v = v / promote(v.norm())
2078
+ for _ in range(power_iter):
2079
+ v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
2080
+ v = v / promote(v.norm())
2081
+ mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))
2082
+
2083
+ lambda_min_hat = lambda_upper - mu_hat
2084
+
2085
+ def _approx():
2086
+ mu = A.trace() / n
2087
+ sigma_square = A.square().sum() / n - mu**2
2088
+ return mu - (sigma_square / (n - 1)).sqrt()
2089
+
2090
+ return cond(
2091
+ (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
2092
+ ).squeeze()
2093
+
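The loop above is a shifted power iteration: for a symmetric matrix with `lambda_upper >= lambda_max`, the largest eigenvalue of `lambda_upper * I - A` is `lambda_upper - lambda_min(A)`. A minimal standalone sketch of that idea (plain fp64 PyTorch, no stochastic rounding):

import torch

A = torch.randn(16, 16, dtype=torch.float64)
A = A @ A.T  # symmetric positive semi-definite
lambda_upper = 1.05 * torch.linalg.matrix_norm(A, ord=2)
v = torch.randn(16, dtype=torch.float64)
v = v / v.norm()
for _ in range(100):
    v = lambda_upper * v - A @ v        # power iteration on the shifted matrix
    v = v / v.norm()
mu = v @ (lambda_upper * v - A @ v)     # Rayleigh quotient of the shifted matrix
lambda_min_estimate = lambda_upper - mu
print(lambda_min_estimate.item(), torch.linalg.eigvalsh(A).min().item())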
2094
+
2095
+ @decorator_knowngood
2096
+ def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
2097
+ if isinstance(Q[0], tuple):
2098
+ psgd_balance_Q([o[1] for o in Q])
2099
+ return line_to_triu(Q, symmetric_output)
2100
+ psgd_balance_Q(Q)
2101
+ return Q
2102
+
2103
+
2104
+ @functools.lru_cache(maxsize=None)
2105
+ def calcG_expr(q_dim, g_dim):
2106
+ exprs = []
2107
+ base = einsum_base[:g_dim]
2108
+ for i, q in enumerate(q_dim):
2109
+ new = list(base)
2110
+ if q == 2:
2111
+ new[i] = "Z"
2112
+ out = f"{base[i]}Z"
2113
+ else:
2114
+ out = base[i]
2115
+ exprs.append(f"{base},{''.join(new)}->{out}")
2116
+ return exprs
2117
+
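For intuition, the cached expressions for a 2-D gradient with two dense (2-D) preconditioner factors contract the gradient with itself over all but one mode. A small usage sketch, assuming `calcG_expr` from this module is in scope:

exprs = calcG_expr((2, 2), 2)
print(exprs)  # ['ab,Zb->aZ', 'ab,aZ->bZ']: G·Gᵀ-style and Gᵀ·G-style contractions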
2118
+
2119
+ def eye_like(x: Tensor):
2120
+ if x.ndim < 2:
2121
+ return torch.ones_like(x)
2122
+ assert x.ndim == 2
2123
+ assert x.size(0) == x.size(1)
2124
+ return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
2125
+
2126
+
2127
+ @decorator_knowngood
2128
+ def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
2129
+ """
2130
+ Idea:
2131
+ LGR should approximate the zeroth power of G, so all Qs together should approximate G's inverse.
2132
+ Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R
2133
+ Optimize LGR toward a zeroth power by minimizing `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T (LGR) , I )`,
2134
+ then backprop to L/R jointly.
2135
+ This function computes the gradients for L/R, with an outer optimizer layer handling the rest.
2136
+
2137
+ `psgd_precond_grad` computes LGR for the general (n-dimensional) case
2138
+ `exprGs` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case
2139
+ Args:
2140
+ G: Gradient that should be orthogonalized
2141
+ Q: List of preconditioner tensors.
2142
+
2143
+ Returns:
2144
+ - List of gradients with respect to Q (d_Q) and the preconditioned gradient P (cast back to G's dtype).
2145
+ """
2146
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2147
+
2148
+ G16 = stochastic_round_(G)
2149
+ Q16 = [stochastic_round_(q) for q in Q]
2150
+ P = psgd_precond_grad(G16, Q16) # Q₀GQ₁
2151
+
2152
+ d_P = torch.zeros_like(G)
2153
+ base = einsum_base[: G.ndim]
2154
+ for i, exprG in enumerate(exprGs):
2155
+ pp = compiled_einsum(exprG, P, P)
2156
+ error = pp - eye_like(pp)
2157
+ dim = einsum_base[i]
2158
+ if pp.ndim == 2:
2159
+ new = dim.upper()
2160
+ prec = f"{new}{dim}"
2161
+ else:
2162
+ new = dim
2163
+ prec = dim
2164
+ d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)
2165
+
2166
+ d_P = stochastic_round_(d_P) # accumulate in fp32 and round at the end
2167
+ grads = []
2168
+ for i, exprG in enumerate(exprGs):
2169
+ new_q = Q16[:]
2170
+ new_q[i] = eye_like(new_q[i])
2171
+ pq = psgd_precond_grad(G16, new_q)
2172
+ grad = compiled_einsum(exprG, pq, d_P)
2173
+ if grad.ndim == 2:
2174
+ grad = (grad + grad.T) / 2
2175
+ grads.append(grad)
2176
+
2177
+ return grads, P.to(G.dtype)
2178
+
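The objective in the docstring can be cross-checked with plain autograd for the 2-D case. A minimal sketch (fp32 throughout, no stochastic rounding, so it matches the fused einsum path above only up to constant factors):

import torch

torch.manual_seed(0)
G = torch.randn(8, 6)
L = torch.eye(8, requires_grad=True)   # left preconditioner factor
R = torch.eye(6, requires_grad=True)   # right preconditioner factor
P = L @ G @ R
loss = ((P @ P.T - torch.eye(8)) ** 2).mean() + ((P.T @ P - torch.eye(6)) ** 2).mean()
loss.backward()
print(L.grad.shape, R.grad.shape)      # gradients an outer optimizer could consume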
2179
+
2180
+ def _inverse_initial_guess(gg):
2181
+ n = gg.shape[0]
2182
+
2183
+ sigma_max = promote(gg.norm())
2184
+
2185
+ trace_gg = promote(torch.trace(gg))
2186
+ sigma_min_approx = trace_gg / (n * sigma_max)
2187
+
2188
+ return sigma_max, sigma_min_approx
2189
+
2190
+
2191
+ @decorator_knowngood
2192
+ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
2193
+ k = torch.arange(degree, dtype=torch.float64, device=device)
2194
+ rotation = (2 * k + 1) * math.pi / (2 * degree)
2195
+ f = (rotation.cos() + 1 + eps) ** -0.5
2196
+ rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()
2197
+ coeff0 = f.sum() / degree
2198
+ coeffs = f @ rotation * 2 / degree
2199
+ return coeff0.float(), coeffs.float()
2200
+
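The coefficients above are the standard Chebyshev-node (DCT-style) expansion of f(theta) = (cos(theta) + 1 + eps)^-0.5. A standalone sketch, using the same node and eps convention, verifying that the expansion reproduces f exactly at the nodes:

import math
import torch

d, eps = 8, 1e-8
k = torch.arange(d, dtype=torch.float64)
theta = (2 * k + 1) * math.pi / (2 * d)                    # Chebyshev nodes as angles
f = (theta.cos() + 1 + eps) ** -0.5
cos_table = (theta.view(-1, 1) * k[1:].view(1, -1)).cos()  # cos(j * theta_k)
c0 = f.sum() / d
cj = f @ cos_table * 2 / d
reconstruction = c0 + cos_table @ cj
print(torch.allclose(reconstruction, f))                   # True: exact at the nodes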
2201
+
2202
+ @decorator_knowngood
2203
+ def _psgd_default_preconditioner_grad(
2204
+ terms: List[Tuple[Tensor, Tensor]],
2205
+ Q: List[Tensor],
2206
+ ) -> List[Tensor]:
2207
+ out = []
2208
+ for q, (x, y) in zip(Q, terms):
2209
+ x = promote(x)
2210
+ y = promote(y)
2211
+ update = x - y
2212
+ if q.ndim < 2:
2213
+ update = q * update
2214
+ else:
2215
+ update = (q @ update).triu()
2216
+ out.append(update)
2217
+ return out
1728
2218
 
1729
2219
 
1730
2220
  @decorator
1731
- def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
2221
+ def psgd_update_precond(
2222
+ G: Tensor,
2223
+ precond_lr: float,
2224
+ oq: "TriuOrLine",
2225
+ store_triu_as_line: bool,
2226
+ velocity: Optional[List[Tensor]],
2227
+ beta2: float,
2228
+ ortho_method: Optional[str],
2229
+ V: Tensor,
2230
+ running_lower_bound: List[Tensor],
2231
+ lower_bount_beta: float,
2232
+ power_iter: int,
2233
+ ) -> None:
1732
2234
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1733
- exprA, exprGs, _ = exprs
1734
- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1735
- precond_lr = scalar_guard(precond_lr, G)
1736
-
1737
- for q, exprG, o in zip(Q, exprGs, oq):
1738
- term1 = torch.einsum(exprG, A, A)
1739
- term2 = torch.einsum(exprG, conjB, conjB)
1740
- term1, term2 = _compilable_add_sub_(term1, term2)
1741
- norm = term2.norm(float("inf"))
1742
- if q.dim() < 2:
1743
- _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
2235
+ Q = _balance_to_triu(oq)
2236
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2237
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2238
+
2239
+ A, conjB = psgd_calc_A_and_conjB(G, Q, V)
2240
+ terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
2241
+ del A, conjB, V
2242
+ updates = _psgd_default_preconditioner_grad(terms, Q)
2243
+ _psgd_precond_update_(
2244
+ updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2245
+ )
2246
+ return None
2247
+
2248
+
2249
+ @decorator_knowngood
2250
+ def bf16_matmul(x: Tensor, y: Tensor):
2251
+ return (promote(x) @ promote(y)).to(x.dtype)
2252
+
2253
+
2254
+ def if_iscompiling(fn):
2255
+ base = getattr(torch, fn.__name__, None)
2256
+
2257
+ def _fn(*args):
2258
+ if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
2259
+ return base(*args)
2260
+ return fn(*args)
2261
+
2262
+ return _fn
2263
+
2264
+
2265
+ @if_iscompiling
2266
+ def while_loop(cond, body, state):
2267
+ """
2268
+ dispatches to torch.while_loop if we're compiling. otherwise, falls back to a naive + slow baseline
2269
+ useful for debugging
2270
+ """
2271
+ while cond(*state).item():
2272
+ state = body(*state)
2273
+ return state
2274
+
2275
+
2276
+ @if_iscompiling
2277
+ def cond(cond, true_fn, false_fn):
2278
+ """
2279
+ dispatches to torch.cond if we're compiling. otherwise, falls back to a naive + slow baseline
2280
+ useful for debugging
2281
+ """
2282
+
2283
+ if cond.item():
2284
+ return true_fn()
2285
+ return false_fn()
2286
+
2287
+
2288
+ def cond_n(cond_val: Tensor, *fns):
2289
+ fns = list(fns)
2290
+ fn = fns.pop(0)
2291
+ if not fns:
2292
+ return fn
2293
+ return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
2294
+
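A small eager-mode usage sketch of the fallbacks above, assuming the wrapper forwards all positional arguments; under `torch.compile` the same calls dispatch to `torch.cond` / `torch.while_loop`:

import torch

picked = cond(torch.tensor(True), lambda: torch.tensor(1.0), lambda: torch.tensor(2.0))
final_state = while_loop(
    lambda i: i < 5,          # loop condition on the carried state
    lambda i: (i + 1,),       # body returns the new state tuple
    (torch.tensor(0),),
)
print(picked.item(), final_state[0].item())  # 1.0 5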
2295
+
2296
+ @decorator_knowngood
2297
+ def _psgd_precond_update_(
2298
+ matmuled: List[Optional[Tensor]],
2299
+ Q: "TriuOrLine",
2300
+ running_lower_bound: List[Tensor],
2301
+ lower_bount_beta: Tensor,
2302
+ precond_lr: Tensor,
2303
+ store_triu_as_line: bool,
2304
+ power_iter: int,
2305
+ ):
2306
+ for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
2307
+ if isinstance(oq, tuple):
2308
+ oq = oq[1]
2309
+
2310
+ q = promote(oq)
2311
+ if update.ndim < 2:
2312
+ lb = update.norm(float("inf"))
1744
2313
  else:
1745
- lower_bound = psgd_lb(term2, norm)
1746
- _prescale_term_(term1, precond_lr, lower_bound, norm)
1747
- torch.mm(term1, q.to(term1.dtype), out=term1)
1748
- if store_triu_as_line:
1749
- _subtract_from_line_(q, term1)
2314
+ lb = max_singular_value(update, power_iter=power_iter)
2315
+ update = promote(update)
2316
+ if store_triu_as_line:
2317
+ update = triu_to_line([update])[0][1]
2318
+
2319
+ lb = promote(lb)
2320
+ lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
2321
+ copy_stochastic_(lb_state, lb)
2322
+ copy_stochastic_(oq, q - update / lb * precond_lr)
2323
+
2324
+
2325
+ @decorator_knowngood
2326
+ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
2327
+ """
2328
+ I: Identity
2329
+ T: target (gg); U: update
2330
+ Q: q, preconditioner
2331
+ scale: scalar scale
2332
+ ---
2333
+ U = T * scale - I
2334
+ F = I - U  # = 2I - T * scale
2335
+ O = Q - F @ Q @ F  # (symmetrized) update returned to the caller, which subtracts it from Q
2336
+ """
2337
+ out = []
2338
+ for gg, q in zip(GG, Q):
2339
+ if gg.ndim < 2:
2340
+ scale = max(1, gg.numel()) / numel
2341
+ target = promote(gg)
2342
+ update = target * scale - 1
2343
+ out.append(q - (1 - update) * q * (1 - update))
1750
2344
  else:
1751
- stochastic_add_(o, term1, -1)
2345
+ scale = gg.size(0) / numel
2346
+ gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
2347
+ update = q - gg @ q @ gg
2348
+ out.append(update + update.T) # make matrix symmetric
2349
+ return out
2350
+
2351
+
2352
+ @decorator
2353
+ def inverse_free_psgd_update_precond(
2354
+ G: Tensor,
2355
+ precond_lr: float,
2356
+ oq: List[Tensor],
2357
+ store_triu_as_line: bool,
2358
+ velocity: Optional[List[Tensor]],
2359
+ beta2: float,
2360
+ ortho_method: Optional[str],
2361
+ V: None,
2362
+ running_lower_bound: List[Tensor],
2363
+ lower_bount_beta: float,
2364
+ power_iter: int,
2365
+ ) -> Tensor:
2366
+ """Update Kronecker product preconditioner Q with pair (V, G)."""
2367
+ assert V is None
2368
+ assert ortho_method is None
2369
+ assert velocity is None
2370
+ del V, ortho_method, velocity
2371
+
2372
+ Q = _balance_to_triu(oq, True)
2373
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2374
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2375
+
2376
+ G = psgd_precond_grad(G, Q)
2377
+ terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
2378
+ matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
2379
+ _psgd_precond_update_(
2380
+ matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2381
+ )
2382
+ return G
1752
2383
 
1753
2384
 
1754
2385
  @decorator_knowngood
1755
- def _compilable_l2_clip_(x, clip_at):
1756
- ref = x
1757
- x = list(map(promote, x))
1758
- norm = torch._foreach_norm(x)
1759
- torch._foreach_maximum_(norm, clip_at)
1760
- out = torch._foreach_div(x, norm)
1761
- return stochastic_round_list_(ref, out)
2386
+ def _clip(x, norm, clip_at, eps=1e-8):
2387
+ x32 = promote(x)
2388
+ # (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
2389
+ norm = clip_at / norm.clamp(min=max(clip_at, eps))
2390
+ x32 = x32 * norm
2391
+ copy_stochastic_(x, x32)
2392
+
2393
+
2394
+ @decorator_knowngood
2395
+ def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
2396
+ for x in xs:
2397
+ _clip(x, promote(x).norm(), clip_at, eps)
1762
2398
 
1763
2399
 
1764
2400
  def l2_normalization_(x, clip_at: float = 1e-8):
1765
2401
  x = list_guard(x)
1766
- return _compilable_l2_clip_(x, clip_at)
2402
+ _compilable_l2_clip_(x, clip_at)
2403
+ return x
1767
2404
 
1768
2405
 
1769
2406
  def l2_clip_(x, clip_at: float = 1.0):
1770
2407
  x = list_guard(x)
1771
- return _compilable_l2_clip_(x, clip_at)
2408
+ _compilable_l2_clip_(x, clip_at)
2409
+ return x
1772
2410
 
1773
2411
 
1774
2412
  @decorator_knowngood
1775
- def _compilable_rmsnorm_clip_(x, clip_at):
1776
- x = list(map(promote, x))
1777
- norm = torch._foreach_norm(x)
1778
- norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
1779
- torch._foreach_maximum_(norm, clip_at)
1780
- return torch._foreach_div(x, norm)
2413
+ def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
2414
+ for x in xs:
2415
+ _clip(x, promote(x).square().mean().sqrt(), clip_at, eps)
1781
2416
 
1782
2417
 
1783
2418
  def rmsnorm_clip_(x, clip_at: float = 1.0):
1784
2419
  x = list_guard(x)
1785
- return _compilable_rmsnorm_clip_(x, clip_at)
2420
+ _compilable_rmsnorm_clip_(x, clip_at)
2421
+ return x
2422
+
2423
+
2424
+ @decorator_knowngood
2425
+ def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
2426
+ norm = 0
2427
+ numel = sum([i.numel() for i in x])
2428
+ for i in x:
2429
+ norm += promote(i).square().sum()
2430
+ norm = (norm / numel) ** 0.5
2431
+ scalar = clip_at / norm.clamp(min=max(clip_at, eps))
2432
+ stochastic_multiply_(x, scalar)
2433
+
2434
+
2435
+ def global_rmsnorm_clip(x, clip_at: float = 1.0):
2436
+ x = list_guard(x)
2437
+ clip_at = scalar_guard(clip_at, x[0])
2438
+ _compilable_global_rmsnorm_clip_(x, clip_at)
2439
+ return x
2440
+
2441
+
2442
+ @decorator_knowngood
2443
+ def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
2444
+ norm = 0
2445
+ for i in x:
2446
+ norm += promote(i).square().sum()
2447
+ norm = norm**0.5
2448
+ scalar = clip_at / norm.clamp(min=max(clip_at, eps))
2449
+ stochastic_multiply_(x, scalar)
2450
+
2451
+
2452
+ def global_l2norm_clip(x, clip_at: float = 1.0):
2453
+ x = list_guard(x)
2454
+ clip_at = scalar_guard(clip_at, x[0])
2455
+ _compilable_global_l2norm_clip_(x, clip_at)
2456
+ return x
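The global variants above scale every tensor by one shared factor, `clip_at / max(global_norm, clip_at)`, so the norm of the whole group never exceeds `clip_at` while small updates pass through untouched. A standalone numeric sketch:

import torch

tensors = [torch.randn(10), torch.randn(4, 4)]
clip_at = 1.0
global_norm = torch.sqrt(sum(t.square().sum() for t in tensors))
scale = clip_at / global_norm.clamp(min=clip_at)
clipped = [t * scale for t in tensors]
clipped_norm = torch.sqrt(sum(t.square().sum() for t in clipped))
print(global_norm.item(), clipped_norm.item())  # second value is <= clip_at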
1786
2457
 
1787
2458
 
1788
2459
  def rmsnorm_normalize_(x, clip_at: float = 1e-6):
1789
2460
  x = list_guard(x)
1790
- return _compilable_rmsnorm_clip_(x, clip_at)
2461
+ _compilable_rmsnorm_clip_(x, clip_at)
2462
+ return x
1791
2463
 
1792
2464
 
1793
2465
  @decorator_knowngood
@@ -1920,35 +2592,25 @@ def triu_to_line(Q_list: List[Tensor]):
1920
2592
  if q.dim() < 2:
1921
2593
  out.append((None, q))
1922
2594
  else:
1923
- out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
2595
+ out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
1924
2596
  return out
1925
2597
 
1926
2598
 
1927
- def _triu_shape(numel):
1928
- n = int((2 * numel) ** 0.5)
1929
- assert n * (n + 1) == 2 * numel
1930
- return n, n
1931
-
1932
-
1933
- @decorator
1934
- def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
2599
+ @decorator_knowngood
2600
+ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
1935
2601
  new = []
1936
2602
  for shape, q in Q_list:
1937
2603
  if shape is not None:
1938
- shape = _triu_shape(q.numel())
1939
- x = torch.zeros(shape, device=q.device, dtype=q.dtype)
1940
- x[tuple(torch.triu_indices(*shape, device=q.device))] = q
1941
- q = x
2604
+ x, y = torch.triu_indices(*shape, device=q.device)
2605
+ q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
2606
+ q_mat[x, y] = q
2607
+ if symmetric_output:
2608
+ q_mat[y, x] = q
2609
+ q = q_mat
1942
2610
  new.append(q)
1943
2611
  return new
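The triu/line storage round-trips losslessly. A self-contained sketch of the packing and unpacking used above, including the `symmetric_output` mirroring:

import torch

Q = torch.triu(torch.randn(5, 5))
shape = tuple(Q.shape)
x, y = torch.triu_indices(*shape)
line = Q[x, y]                      # pack: flat vector of upper-triangular entries
unpacked = torch.zeros(shape, dtype=Q.dtype)
unpacked[x, y] = line               # unpack
symmetric = unpacked.clone()
symmetric[y, x] = line              # symmetric_output=True analogue
assert torch.equal(unpacked, Q)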
1944
2612
 
1945
2613
 
1946
- def update_triu_(q_state, materialised):
1947
- for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
1948
- assert shape0 == shape1
1949
- copy_stochastic_(q, m)
1950
-
1951
-
1952
2614
  _warned = set()
1953
2615
 
1954
2616
 
@@ -1971,52 +2633,118 @@ def psgd_should_update(
1971
2633
  return int(group[name]) > int(cumulative_prob)
1972
2634
 
1973
2635
 
2636
+ @functools.lru_cache(maxsize=None)
2637
+ def cached_precond_grad_expr(Q_dim, grad_dim):
2638
+ expr = [f"{c.upper()}{c}" if q_ == 2 else c for c, q_ in zip(einsum_base, Q_dim)]
2639
+ expr = ",".join(expr)
2640
+ grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
2641
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
2642
+ return f"{expr},{grad_expr}->{out_expr}"
2643
+
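For a 2-D gradient with two dense cached factors, the generated expression applies one cached factor per gradient mode. A small usage sketch, assuming `cached_precond_grad_expr` from this module is in scope:

print(cached_precond_grad_expr((2, 2), 2))  # 'Aa,Bb,ab->AB': one cached factor per mode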
2644
+
1974
2645
  @decorator_knowngood
1975
2646
  def precond_grad_cached_(
1976
- expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
2647
+ ea: Tensor,
2648
+ cached_q: List[Tensor],
2649
+ caution: bool = False,
2650
+ grad: Optional[Tensor] = None,
2651
+ cast: bool = True,
1977
2652
  ):
1978
2653
  if caution:
1979
2654
  ea = _compilable_cautioning(grad, ea)
1980
2655
  md = min_dtype(list(cached_q) + [ea])
1981
2656
  args = [q.to(md) for q in cached_q]
1982
2657
  args = args + [ea.to(md)]
1983
- new = torch.einsum(expr, *args)
2658
+ expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
2659
+ new = compiled_einsum(expr, *args)
1984
2660
  if cast:
1985
2661
  return new.to(ea.dtype)
1986
2662
  return new
1987
2663
 
1988
2664
 
2665
+ TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
2666
+
2667
+
1989
2668
  @decorator_knowngood
1990
- def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
1991
- precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
2669
+ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
2670
+ precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
1992
2671
  update_param_(param, precond, lr, decay, caution=False)
1993
2672
 
1994
2673
 
1995
- def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
1996
- lr = scalar_guard(lr, param[0])
1997
- _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
2674
+ def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
2675
+ lr, decay = scalar_guard(lr, decay, param[0])
2676
+ _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
2677
+
2678
+
2679
+ @functools.lru_cache(maxsize=None)
2680
+ def precond_grad_expr(Q_dim, grad_dim):
2681
+ expr = [
2682
+ f"{c2}{c.upper()},{c2}{c}" if q_ == 2 else f"{c},{c}" for c, c2, q_ in zip(einsum_base, einsum_base[13:], Q_dim)
2683
+ ]
2684
+ expr = ",".join(expr)
2685
+ grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
2686
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
2687
+ return f"{expr},{grad_expr}->{out_expr}"
1998
2688
 
1999
2689
 
2000
2690
  @decorator_knowngood
2001
- def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
2691
+ def psgd_precond_grad(
2692
+ ea: Tensor,
2693
+ preconds: TriuOrLine,
2694
+ caution: bool = False,
2695
+ grad: Optional[Tensor] = None,
2696
+ store_triu_as_line: bool = False,
2697
+ symmetric_output: bool = False,
2698
+ ):
2002
2699
  if caution:
2003
2700
  ea = _compilable_cautioning(grad, ea)
2701
+ if store_triu_as_line:
2702
+ preconds = line_to_triu(preconds, symmetric_output)
2004
2703
  md = min_dtype(list(preconds) + [ea])
2005
2704
  args = [q.to(md) for q in preconds]
2006
- args = args + args + [ea.to(md)]
2007
- new = torch.einsum(expr, *args)
2705
+ expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
2706
+ new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
2008
2707
  return new.to(ea.dtype)
2009
2708
 
2010
2709
 
2011
2710
  @decorator_knowngood
2012
- def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
2013
- precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
2711
+ def _compilable_fused_psgd_precond_grad(
2712
+ ea: Tensor,
2713
+ param,
2714
+ lr,
2715
+ grad,
2716
+ decay,
2717
+ caution,
2718
+ preconds: TriuOrLine,
2719
+ store_triu_as_line: bool = False,
2720
+ symmetric_output: bool = False,
2721
+ ):
2722
+ precond = psgd_precond_grad(
2723
+ ea,
2724
+ preconds,
2725
+ caution=caution,
2726
+ grad=grad,
2727
+ store_triu_as_line=store_triu_as_line,
2728
+ symmetric_output=symmetric_output,
2729
+ )
2014
2730
  update_param_(param, precond, lr, decay, caution=False, grad=grad)
2015
2731
 
2016
2732
 
2017
- def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
2018
- lr = scalar_guard(lr, param[0])
2019
- _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)
2733
+ def fused_psgd_precond_grad(
2734
+ ea: Tensor,
2735
+ param,
2736
+ lr,
2737
+ grad,
2738
+ decay,
2739
+ caution,
2740
+ preconds: TriuOrLine,
2741
+ store_triu_as_line: bool = False,
2742
+ symmetric_output: bool = False,
2743
+ ):
2744
+ lr, decay = scalar_guard(lr, decay, param[0])
2745
+ _compilable_fused_psgd_precond_grad(
2746
+ ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
2747
+ )
2020
2748
 
2021
2749
 
2022
2750
  @decorator_knowngood
@@ -2068,7 +2796,15 @@ def caution(g, update):
2068
2796
  return _compilable_cautioning(g, update)
2069
2797
 
2070
2798
 
2071
- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
2799
+ def _inner_precond_update_prob_schedule(
2800
+ n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
2801
+ ):
2802
+ return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
2803
+
2804
+
2805
+ def precond_update_prob_schedule(
2806
+ max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
2807
+ ):
2072
2808
  """Anneal preconditioner update probability during beginning of training.
2073
2809
 
2074
2810
  PSGD benefits from more preconditioner updates at the beginning of training,
@@ -2079,11 +2815,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
2079
2815
  `min_prob` by ~4000 steps. Default settings work very well for most models and
2080
2816
  training regimes.
2081
2817
  """
2082
-
2083
- def _schedule(n):
2084
- return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
2085
-
2086
- return _schedule
2818
+ return functools.partial(
2819
+ _inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
2820
+ )
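A usage sketch of the refactored schedule: a module-level `functools.partial` is picklable, unlike the previous closure, and its behaviour at call time is unchanged.

schedule = precond_update_prob_schedule(flat_start=500)
print([schedule(n) for n in (0, 500, 1000, 4000)])  # 1.0 until step 500, then decays toward min_prob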
2087
2821
 
2088
2822
 
2089
2823
  def merge_group(group, *tensors):
@@ -2217,3 +2951,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
2217
2951
  def disable_caution_scaling():
2218
2952
  global _compilable_cautioning
2219
2953
  _compilable_cautioning = _compilable_caution_no_scale
2954
+
2955
+
2956
+ @decorator_knowngood
2957
+ def sam_step(parameters, ball_size, adaptive: bool = True):
2958
+ old_params = []
2959
+ for p in parameters:
2960
+ old_params.append(p.detach().clone())
2961
+ grad = promote(p.grad)
2962
+ if adaptive:
2963
+ grad = grad * promote(p).square()
2964
+ stochastic_add_(p.data, grad, ball_size)
2965
+ p.grad.zero_()
2966
+ return old_params
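A hedged usage sketch of how `sam_step` could slot into a SAM-style update; the restore-and-step wiring below is an assumption for illustration, not something defined in this file:

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 4), torch.randn(8, 1)

((model(x) - y) ** 2).mean().backward()         # gradients at the original weights
old = sam_step(list(model.parameters()), 0.05)  # perturb weights along the (adaptive) gradient, zero grads
((model(x) - y) ** 2).mean().backward()         # gradients at the perturbed weights
with torch.no_grad():
    for p, o in zip(model.parameters(), old):
        p.copy_(o)                              # restore the original weights
opt.step()                                      # apply the perturbed-point gradients
opt.zero_grad()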