heavyball 1.7.1__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,28 +1,28 @@
1
+ import collections
1
2
  import contextlib
2
3
  import functools
3
4
  import gc
4
5
  import inspect
5
6
  import math
7
+ import pickle
6
8
  import random
7
9
  import re
8
10
  import string
9
11
  import warnings
10
- from typing import Callable, List, Optional, Tuple, Union
12
+ from typing import Callable, List, Literal, Optional, Tuple, Union
11
13
 
12
14
  import numpy as np
13
15
  import torch
14
16
  from torch import Tensor
15
- from torch._dynamo import config
16
17
  from torch._dynamo.exc import TorchDynamoException
17
18
  from torch.backends import cudnn, opt_einsum
18
19
  from torch.utils._pytree import tree_map
19
20
 
20
- config.cache_size_limit = 2**16
21
-
22
21
  compile_mode = "max-autotune-no-cudagraphs"
23
22
  dynamic = False
24
23
  compile_mode_recommended_to_none = None
25
- zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
24
+ zeroth_power_mode = "newtonschulz"
25
+ precise_zeroth_power_mode = "qr" # or svd
26
26
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
27
  _cudnn_double_backward_pattern = re.compile(
28
28
  r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -50,7 +50,7 @@ def decorator(func):
50
50
  return _fn
51
51
 
52
52
 
53
- def decorator_knowngood(func: Callable):
53
+ def decorator_knowngood(func: Callable, fullgraph: bool = True):
54
54
  compiled = None
55
55
 
56
56
  @functools.wraps(func)
@@ -59,7 +59,7 @@ def decorator_knowngood(func: Callable):
59
59
  return func(*args, **kwargs)
60
60
  nonlocal compiled
61
61
  if compiled is None:
62
- compiled = torch.compile(fullgraph=True, dynamic=dynamic, mode=compile_mode)(func)
62
+ compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
63
63
  return compiled(*args, **kwargs)
64
64
 
65
65
  return _fn
@@ -68,6 +68,16 @@ def decorator_knowngood(func: Callable):
68
68
  einsum_base = string.ascii_lowercase
69
69
 
70
70
 
71
+ @decorator_knowngood
72
+ def compiled_einsum(expr, *args):
73
+ """
74
+ This is necessary to avoid the slowdown introduced by uncompiled einsum:
75
+ uncompiled einsum is twice as slow if we add three 1-sized dimensions.
76
+ for more, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
77
+ """
78
+ return torch.einsum(expr, *args)
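# Minimal usage sketch of the wrapper above (shapes are hypothetical, chosen to mirror the
# "three 1-sized dimensions" case from the docstring): the compiled path returns the same values
# as eager torch.einsum while avoiding the eager re-planning overhead described in the gist.
import torch

from heavyball.utils import compiled_einsum

expr = "abcde,abcde->a"              # contraction padded with three size-1 dimensions
x = torch.randn(64, 128, 1, 1, 1)
y = torch.randn(64, 128, 1, 1, 1)

eager = torch.einsum(expr, x, y)     # eager einsum re-plans the contraction on every call
fused = compiled_einsum(expr, x, y)  # first call compiles, later calls reuse the compiled kernel
assert torch.allclose(eager, fused, rtol=1e-4, atol=1e-4)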
79
+
80
+
71
81
  @decorator_knowngood
72
82
  def _compilable_schedule_free_(
73
83
  p: List[Tensor],
@@ -122,6 +132,47 @@ def schedule_free_(
122
132
  return weight_sum
123
133
 
124
134
 
135
+ @decorator_knowngood
136
+ def _compilable_msam(
137
+ lr: Tensor,
138
+ beta1: Tensor,
139
+ param: List[Tensor],
140
+ z: List[Tensor],
141
+ update: List[Tensor],
142
+ grad: List[Tensor],
143
+ exp_avg: List[Tensor],
144
+ caution: bool,
145
+ decay: Tensor,
146
+ sam_step_size: Tensor,
147
+ ):
148
+ exp_avg32 = _lerp(exp_avg, update, beta1)
149
+ for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
150
+ u_ = u_.view_as(z_)
151
+ z32_ = promote(z_)
152
+ if caution:
153
+ u_ = _compilable_cautioning(promote(g_), u_)
154
+ z32_ = z32_ * (1 - decay * lr) + u_ * -lr
155
+ copy_stochastic_(z_, z32_)
156
+ copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
157
+
158
+
159
+ def msam_(
160
+ lr: float,
161
+ beta1: float,
162
+ param: List[Tensor],
163
+ z: List[Tensor],
164
+ update: List[Tensor],
165
+ grad: List[Tensor],
166
+ exp_avg: List[Tensor],
167
+ caution: bool,
168
+ weight_decay: float,
169
+ sam_step_size: float,
170
+ ):
171
+ param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
172
+ lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
173
+ _compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
174
+
175
+
125
176
  def append_or_extend(base, new):
126
177
  if isinstance(new, list):
127
178
  base.extend(new)
@@ -161,7 +212,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
161
212
  new_shape = [grad.shape[0], *new_shape[::-1]]
162
213
  new_grad = grad.reshape(new_shape)
163
214
  if not split:
164
- return new_grad
215
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
165
216
 
166
217
  grads = [new_grad]
167
218
  for i, sh in reversed(list(enumerate(new_shape[:]))):
@@ -172,7 +223,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
172
223
  continue
173
224
  grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
174
225
  if len(grads) == 1:
175
- return new_grad
226
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
176
227
  new_grads = []
177
228
  for g in grads:
178
229
  append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
@@ -279,16 +330,29 @@ def clean():
279
330
 
280
331
 
281
332
  def _ignore_warning(msg):
282
- warnings.filterwarnings("ignore", f".*{msg}.*")
333
+ warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
334
+
283
335
 
336
+ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
337
+ import opt_einsum as _opt_einsum
284
338
 
285
- def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
286
339
  cudnn.benchmark = True
287
340
  cudnn.deterministic = False
288
341
  cudnn.benchmark_limit = benchmark_limit
289
342
  torch.use_deterministic_algorithms(False)
290
343
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
291
- opt_einsum.set_flags(True, einsum_strategy)
344
+ opt_einsum.set_flags(True)
345
+ if einsum_strategy == "heavyball":
346
+ opt_einsum.strategy = "auto-hq"
347
+ choices = _opt_einsum.paths._AUTO_HQ_CHOICES
348
+ for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
349
+ if isinstance(fn, int):
350
+ fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
351
+ for i in range(max(choices.keys()), max_val):
352
+ if i not in choices:
353
+ choices[i] = fn
354
+ else:
355
+ opt_einsum.strategy = einsum_strategy
292
356
 
293
357
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
294
358
  _ignore_warning(
@@ -297,6 +361,9 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
297
361
  _ignore_warning(
298
362
  "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
363
  )
364
+ _ignore_warning(
365
+ "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
366
+ )
300
367
 
301
368
 
302
369
  @decorator
@@ -316,15 +383,6 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
316
383
  return X.to(G.dtype)
317
384
 
318
385
 
319
- def ortho(x):
320
- if zeroth_power_mode == "qr":
321
- return torch.linalg.qr(x).Q
322
- if zeroth_power_mode == "svd":
323
- u, _s, v = torch.linalg.svd(x)
324
- return u @ v.T
325
- raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
326
-
327
-
328
386
  @decorator_knowngood
329
387
  def _compilable_heavyball_momentum_(state, grad, beta):
330
388
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
@@ -377,7 +435,7 @@ def _compilable_grafting(magnitude, direction):
377
435
 
378
436
 
379
437
  @decorator_knowngood
380
- def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
438
+ def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
381
439
  if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
382
440
  y = zeropower_via_newtonschulz5(x, 5)
383
441
  elif mode == "qr":
@@ -395,9 +453,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
395
453
  y = _compilable_grafting(x, y)
396
454
  else:
397
455
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
456
+ if out is None:
457
+ return y
458
+
398
459
  set_(out, y)
399
460
 
400
461
 
462
+ def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
463
+ return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
464
+
465
+
401
466
  @decorator_knowngood
402
467
  def _compilable_scatter_set(target, source, index):
403
468
  target[:] = source.contiguous()[index].reshape_as(target)
@@ -413,6 +478,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
413
478
  :param Q: List of current eigenbases (updated in-place to Q_new).
414
479
  :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
415
480
  """
481
+ if exp_avg.dim() == 0: # preconditioning doesn't make sense here
482
+ Q.clear()
483
+ return
484
+
416
485
  if isinstance(Q, list) and not Q:
417
486
  return
418
487
 
@@ -430,10 +499,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
430
499
  q_old = promote(q.data)
431
500
 
432
501
  tmp = m @ q_old
433
- est_eig = torch.einsum("ij,ij->j", q_old, tmp)
502
+ est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
434
503
  sort_idx = torch.argsort(est_eig, descending=True)
435
504
 
436
- tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
505
+ tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
437
506
  new_qs.append(tmp)
438
507
 
439
508
  if exp_avg is None:
@@ -453,7 +522,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
453
522
  out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
454
523
 
455
524
  subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
456
- exp_avg_new = torch.einsum(
525
+ exp_avg_new = compiled_einsum(
457
526
  subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
527
  )
459
528
  copy_stochastic_(exp_avg, exp_avg_new)
@@ -568,6 +637,19 @@ def scalar_guard(*args):
568
637
  return out
569
638
 
570
639
 
640
+ def broadcastable_list_guard(*xs):
641
+ xs = list_guard(*xs)
642
+ for x in xs:
643
+ if isinstance(x[0], Tensor):
644
+ ref = x[0]
645
+ break
646
+ else:
647
+ raise ValueError("No tensor-valued input given")
648
+ xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
649
+ max_len = max(len(x) for x in xs)
650
+ return [x if len(x) > 1 else x * max_len for x in xs]
651
+
652
+
571
653
  @decorator_knowngood
572
654
  def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
573
655
  for x_, y_ in zip(x, y):
@@ -577,7 +659,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
577
659
 
578
660
 
579
661
  def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
580
- x, y = list_guard(x, y)
662
+ x, y = broadcastable_list_guard(x, y)
581
663
  alpha = scalar_guard(alpha, x[0])
582
664
  _compilable_stochastic_add_(x, y, alpha)
583
665
 
@@ -591,7 +673,7 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
591
673
 
592
674
 
593
675
  def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
594
- x, y = list_guard(x, y)
676
+ x, y = broadcastable_list_guard(x, y)
595
677
  alpha, divisor = scalar_guard(alpha, divisor, x[0])
596
678
  _compilable_stochastic_add_divide_(x, y, alpha, divisor)
597
679
 
@@ -605,7 +687,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
605
687
 
606
688
 
607
689
  def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
608
- x, y = list_guard(x, y)
690
+ x, y = broadcastable_list_guard(x, y)
609
691
  _compilable_stochastic_multiply_(x, y)
610
692
 
611
693
 
@@ -624,7 +706,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
624
706
  b = einsum_base[idx]
625
707
  g0 = einsum_base[: grad.dim()]
626
708
  g1 = g0.replace(b, b.upper())
627
- outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
709
+ outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
628
710
  stochastic_lerp_(m, outer_product, 1 - beta)
629
711
 
630
712
 
@@ -706,7 +788,7 @@ def project(grad, Q, back: bool):
706
788
  preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
707
789
  if preconditioners:
708
790
  out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
709
- out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
791
+ out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
710
792
  grad = out.to(grad.dtype)
711
793
  return grad
712
794
 
@@ -714,24 +796,28 @@ def project(grad, Q, back: bool):
714
796
  @contextlib.contextmanager
715
797
  def patch_backward():
716
798
  @contextlib.contextmanager
717
- def _inner(module):
799
+ def patch_module(module):
718
800
  original = module.backward
719
-
720
- signature = inspect.signature(original)
721
-
722
- def patched_backward(*args, **kwargs):
723
- new_kwargs = signature.bind(*args)
724
- new_kwargs.apply_defaults()
725
- new_kwargs = new_kwargs.arguments
726
- new_kwargs.update(kwargs)
727
- new_kwargs["create_graph"] = True
728
- return original(**new_kwargs)
729
-
730
- module.backward = patched_backward
731
- yield
732
- module.backward = original
733
-
734
- with _inner(torch.Tensor), _inner(torch.autograd):
801
+ try:
802
+ signature = inspect.signature(original)
803
+
804
+ @functools.wraps(original)
805
+ def patched_backward(*args, **kwargs):
806
+ new_kwargs = signature.bind(*args)
807
+ new_kwargs.apply_defaults()
808
+ new_kwargs = new_kwargs.arguments
809
+ new_kwargs.update(kwargs)
810
+ new_kwargs["create_graph"] = True
811
+ return original(**new_kwargs)
812
+
813
+ module.backward = patched_backward
814
+ yield
815
+ finally:
816
+ module.backward = original
817
+
818
+ with contextlib.ExitStack() as stack:
819
+ stack.enter_context(patch_module(torch.Tensor))
820
+ stack.enter_context(patch_module(torch.autograd))
735
821
  yield
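# Short usage sketch (hypothetical parameter, not heavyball's public API): inside
# patch_backward(), a closure written with a plain loss.backward() still builds the graph needed
# for a second backward pass, because the patched backward injects create_graph=True.
import torch

from heavyball.utils import patch_backward

w = torch.nn.Parameter(torch.randn(4))

def closure():
    loss = (w ** 2).sum()
    loss.backward()  # becomes loss.backward(create_graph=True) under the patch
    return loss

with patch_backward():
    closure()
    hvp_ready = torch.autograd.grad(w.grad.sum(), w)  # second-order quantities are now available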
736
822
 
737
823
 
@@ -743,6 +829,9 @@ class ExactHVPFailed(ValueError):
743
829
  pass
744
830
 
745
831
 
832
+ use_default = object()
833
+
834
+
746
835
  class StatefulOptimizer(torch.optim.Optimizer):
747
836
  """
748
837
  finite_differences saves memory, but needs more compute. (Alternative is true HVP)
@@ -755,7 +844,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
755
844
  compile_step: bool = False
756
845
  hessian_approx: bool = False
757
846
  precond_schedule: Union[Callable, float, None] = None
758
- stochastic_schedule: bool = False
847
+ stochastic_schedule: bool | Literal[use_default] = use_default
759
848
  finite_differences: bool = False
760
849
  fallback_to_finite_differences: bool = True
761
850
  _fallback_enabled: bool = False
@@ -765,18 +854,61 @@ class StatefulOptimizer(torch.optim.Optimizer):
765
854
  super().__init__(params, {**defaults, "foreach": foreach})
766
855
  self.use_ema = use_ema
767
856
  self.mapping = {}
768
- self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
769
- self._precond_rng = random.Random(0x12312)
857
+ self.mapping_inverse = {}
858
+
859
+ if self.stochastic_schedule is use_default:
860
+ stochastic_schedule = None
861
+ for group in self.param_groups:
862
+ new = group.get("stochastic_schedule", stochastic_schedule)
863
+ if stochastic_schedule is not None and new != stochastic_schedule:
864
+ raise ValueError("All parameter groups must have the same stochastic_schedule.")
865
+ stochastic_schedule = new
866
+ self.stochastic_schedule = stochastic_schedule
867
+
868
+ self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
869
+ self.precond_rng = random.Random(0x12312)
770
870
  self._is_preconditioning = None
771
871
 
772
872
  if self.hessian_approx and self.compile_step:
773
873
  raise ValueError("Hessian approximation can't be used with compile_step.")
774
874
 
875
+ self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
876
+ self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
877
+ self._init_mapping()
878
+
879
+ def _store_stats(self, state_dict: dict[str, any]):
880
+ state_dict["heavyball"] = {
881
+ "inner_group": self.inner_group,
882
+ "precond_rng": pickle.dumps(self.precond_rng),
883
+ "use_ema": self.use_ema,
884
+ "ema_decay": self.ema_decay,
885
+ "compile_step": self.compile_step,
886
+ "hessian_approx": self.hessian_approx,
887
+ "precond_schedule": pickle.dumps(self.precond_schedule),
888
+ "stochastic_schedule": self.stochastic_schedule,
889
+ "fallback_to_finite_differences": self.fallback_to_finite_differences,
890
+ "_fallback_enabled": self._fallback_enabled,
891
+ "hvp_interval": self.hvp_interval,
892
+ }
893
+
894
+ def _load_stats(self, state_dict):
895
+ sd = state_dict.pop("heavyball", {})
896
+ for k, v in sd.items():
897
+ if k in ("precond_rng", "precond_schedule"):
898
+ v = pickle.loads(v)
899
+ setattr(self, k, v)
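# Checkpointing sketch for the two hooks above ("model" and "opt" are assumed names; "opt" is any
# optimizer built on StatefulOptimizer): the post-hook stores optimizer-level state, including the
# pickled precond_rng, under state_dict()["heavyball"], and the pre-hook restores it, so a plain
# save/load round-trip keeps the stochastic preconditioner schedule and HVP counter in sync.
import torch

checkpoint = {"model": model.state_dict(), "opt": opt.state_dict()}
torch.save(checkpoint, "step_1000.pt")

restored = torch.load("step_1000.pt", weights_only=False)  # the state contains pickled bytes
model.load_state_dict(restored["model"])
opt.load_state_dict(restored["opt"])  # _load_stats pops the "heavyball" entry and re-applies it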
900
+
775
901
  def get_groups(self, group):
776
902
  return [group]
777
903
 
778
- def state_(self, arg: Tensor):
779
- return self.state[arg]
904
+ @functools.lru_cache(maxsize=None)
905
+ def state_(self, arg: Tensor, fail: bool = True):
906
+ if not fail and arg not in self.mapping:
907
+ return {}
908
+ state_param, index = self.mapping_inverse[arg]
909
+ if state_param not in self.state:
910
+ self.state[state_param] = collections.defaultdict(dict)
911
+ return self.state[state_param][index]
780
912
 
781
913
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
782
914
  for p, g in zip(p_list, g_list):
@@ -786,6 +918,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
786
918
  old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
787
919
  mars_correction(g_list, old_gs, mars_gamma, beta)
788
920
 
921
+ def _init_mapping(self, group: dict | None = None):
922
+ if group is None:
923
+ for group in self.param_groups:
924
+ self._init_mapping(group)
925
+ return
926
+
927
+ for p in group["params"]:
928
+ if p not in self.mapping:
929
+ self.mapping[p] = p_views = merge_group(group, p)
930
+ for i, pv in enumerate(p_views):
931
+ self.mapping_inverse[pv] = (p, i)
932
+
789
933
  def split_p_and_g_in_group(
790
934
  self,
791
935
  group: dict,
@@ -809,6 +953,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
809
953
  p_views = self.mapping[p]
810
954
  else:
811
955
  self.mapping[p] = p_views = merge_group(group, p)
956
+ for i, pv in enumerate(p_views):
957
+ self.mapping_inverse[pv] = (p, i)
812
958
 
813
959
  vector = getattr(p, "vector", None)
814
960
  hessian_vector = getattr(p, "hessian_vector", None)
@@ -957,8 +1103,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
957
1103
  raise ValueError("Hessian approximation requires a closure.")
958
1104
  return None
959
1105
 
960
- step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
- if not hessian_approx or step % self.hvp_interval == 0:
1106
+ step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
1107
+ if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
962
1108
  with torch.enable_grad():
963
1109
  loss = closure()
964
1110
  return loss
@@ -997,12 +1143,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
997
1143
  if self.precond_schedule is None:
998
1144
  self._is_preconditioning = False
999
1145
  else:
1000
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
1146
+ self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
1001
1147
  loss = self._handle_closure(closure)
1002
1148
 
1003
1149
  # we assume that parameters are constant and that there are no excessive recompiles
1004
1150
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
1005
1151
  for group in self.param_groups:
1152
+ if "param_count" not in group:
1153
+ group["param_count"] = sum(p.numel() for p in group["params"])
1006
1154
  group["is_preconditioning"] = self._is_preconditioning
1007
1155
  self._step(group)
1008
1156
  if self.use_ema:
@@ -1306,74 +1454,115 @@ def stable_exp(x: Tensor):
1306
1454
  return torch.where(x > 0, 1 / (-x).exp(), x.exp())
1307
1455
 
1308
1456
 
1457
+ def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
1458
+ # ln(mean(x ** pow) ** (1 / pow / 2))
1459
+ normalization = math.log(x.numel())
1460
+ x = x.double()
1461
+ x = x.abs()
1462
+ x = x.clamp(min=eps)
1463
+ x = x.log()
1464
+ x = x * pow
1465
+ x = x.flatten()
1466
+ x = x.logsumexp(dim=0) # log(sum(exp(log(x) * P))) - more stable than sum(x ** P)
1467
+ x = x - normalization # sum -> mean (divide by x.numel() in log space)
1468
+ return x / pow / 2
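# Numeric sketch for _lse_mean above: on moderate inputs it matches the direct formula
# ln(mean(|x| ** pow) ** (1 / pow / 2)), and for large powers it stays finite where the direct
# x ** pow computation overflows even in float64.
import torch

from heavyball.utils import _lse_mean

x = torch.randn(1000)
direct = x.double().abs().clamp(min=1e-12).pow(4).mean().pow(1 / 4 / 2).log()
stable = _lse_mean(x, 4.0, 1e-12)
assert torch.allclose(direct, stable, atol=1e-6)

huge = torch.full((16,), 1e60)
print(huge.double().pow(8).mean())       # inf: the direct power overflows float64
print(_lse_mean(huge, 8.0, 1e-12))       # finite result computed in log space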
1469
+
1470
+
1309
1471
  @decorator_knowngood
1310
1472
  def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
1311
1473
  # 1 / (mean(x ** pow) ** (1 / pow / 2))
1312
- log_x = x.double().abs().clamp(min=eps).log()
1313
- log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
1314
- return stable_exp(-log_mean_x_pow / pow / 2)
1474
+ return stable_exp(-_lse_mean(x, pow, eps))
1315
1475
 
1316
1476
 
1317
1477
  @decorator_knowngood
1318
1478
  def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
1319
1479
  # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
1320
- log_x = x.double().abs().clamp(min=eps).log()
1321
- log_y = y.double().abs().clamp(min=eps).log()
1322
-
1323
- x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
1324
- x_normed = x_normed / pow0 / 2
1480
+ return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
1325
1481
 
1326
- y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
1327
- y_normed = y_normed / pow1 / 2
1328
1482
 
1329
- return stable_exp(x_normed - y_normed)
1483
+ class PrecondInitError(ValueError):
1484
+ pass
1330
1485
 
1331
1486
 
1332
- def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
1487
+ def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
1333
1488
  automatic_scale = True
1334
1489
  manual_hint = " Set it manually using `precond_init_scale=0.1`"
1490
+ scale_scale = 1 if scale_scale is None else scale_scale
1491
+
1335
1492
  if scale is not None:
1336
1493
  automatic_scale = False
1337
1494
  warn_once(
1338
1495
  "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
1339
1496
  )
1340
- if scale_scale is not None and scale_scale != 1:
1497
+ if scale_scale != 1:
1341
1498
  warn_once(
1342
- "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
1499
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
1500
+ )
1501
+ if scale_power is not None:
1502
+ warn_once(
1503
+ "precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
1343
1504
  )
1344
1505
  elif hessian_vector is None:
1345
1506
  scale = mean_root(grad, 4) * scale_scale
1346
1507
  else:
1347
1508
  scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
1509
+
1510
+ if automatic_scale:
1511
+ scale_power = 0.5 if scale_power is None else scale_power
1512
+ scale = scale**scale_power
1513
+
1348
1514
  if isinstance(scale, torch.Tensor):
1349
1515
  scale = scale.item() # slow, but necessary
1516
+
1350
1517
  if np.isfinite(scale):
1351
- if scale > scale_max or scale < 1 / scale_max:
1518
+ if scale > scale_max: # fallthrough to later checks
1352
1519
  warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
1353
- return scale
1520
+ else:
1521
+ return scale
1522
+
1354
1523
  if not automatic_scale:
1355
- raise ValueError("The manually set precond_init_scale is not finite")
1524
+ raise PrecondInitError("The manually set precond_init_scale is not finite")
1356
1525
 
1357
1526
  for x in (grad, hessian_vector, vector):
1358
1527
  if x is None:
1359
1528
  continue
1360
- if torch.allclose(x, torch.zeros_like(x)).item():
1361
- raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
1529
+ if torch.allclose(x, torch.zeros_like(x)):
1530
+ raise PrecondInitError(
1531
+ f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
1532
+ )
1362
1533
  if not torch.isfinite(x).all().item():
1363
- raise ValueError("Grad or HVP is not finite")
1364
- raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1534
+ raise PrecondInitError("Grad or HVP is not finite")
1535
+
1536
+ if np.isfinite(scale):
1537
+ return scale
1538
+
1539
+ raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
1365
1540
 
1366
1541
 
1367
- def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
1368
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1369
- U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1370
- V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1542
+ def init_lra(
1543
+ grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
1544
+ ):
1545
+ # "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
1546
+ # https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
1547
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1548
+ uv_scale = (param_count * (rank + eps)) ** -0.5
1549
+ U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1550
+ V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1371
1551
  d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
1372
1552
  return U, V, d
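# Quick check of the uv_scale heuristic above (sizes are hypothetical): with
# uv_scale = (param_count * (rank + 10)) ** -0.5 the random factors satisfy ||U|| * ||V|| << 1,
# which upper-bounds ||U @ V.T||, so Q = diag(d) + U V' starts out close to the diagonal init.
import torch

param_count, rank = 10_000, 1
uv_scale = (param_count * (rank + 10)) ** -0.5
U = torch.randn(param_count, rank) * uv_scale
V = torch.randn(param_count, rank) * uv_scale
print(U.norm() * V.norm())  # typically ~0.1, well below 1 even at rank 1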
1373
1553
 
1374
1554
 
1375
1555
  def init_Q_exprs(
1376
- grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1556
+ grad,
1557
+ scale,
1558
+ scale_scale,
1559
+ scale_power,
1560
+ max_size,
1561
+ min_ndim_triangular,
1562
+ memory_save_mode,
1563
+ hessian_vector,
1564
+ vector,
1565
+ dtype=None,
1377
1566
  ):
1378
1567
  """
1379
1568
  For a scalar or tensor `grad`, we initialize its preconditioner Q and
@@ -1382,21 +1571,13 @@ def init_Q_exprs(
1382
1571
  precond init scale computation from
1383
1572
  https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
1384
1573
  """
1385
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1386
- letters = string.ascii_lowercase + string.ascii_uppercase
1574
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1387
1575
  dtype = dtype if dtype is not None else grad.dtype
1388
1576
  shape = grad.shape
1389
1577
 
1390
1578
  if len(shape) == 0: # scalar
1391
1579
  Q = [scale * torch.ones_like(grad, dtype=dtype)]
1392
- exprA = ",->"
1393
- exprGs = [",->"]
1394
- exprP = ",,->"
1395
- return [Q, (exprA, tuple(exprGs), exprP)]
1396
-
1397
- # Tensor
1398
- if len(shape) > 13:
1399
- raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
1580
+ return Q
1400
1581
 
1401
1582
  scale = scale ** (1 / len(shape))
1402
1583
 
@@ -1409,6 +1590,9 @@ def init_Q_exprs(
1409
1590
  sorted_shape = sorted(shape)
1410
1591
  if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
1411
1592
  dim_diag[_max_idx(shape)] = True
1593
+ elif memory_save_mode == "one_triu":
1594
+ shape_ranks = np.argsort(np.argsort(shape)) # ranks
1595
+ dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
1412
1596
  elif memory_save_mode == "all_diag":
1413
1597
  dim_diag = [True for _ in shape]
1414
1598
  else:
@@ -1418,66 +1602,90 @@ def init_Q_exprs(
1418
1602
  )
1419
1603
 
1420
1604
  Q = []
1421
- piece1A, piece2A, piece3A = ([], "", "")
1422
- exprGs = []
1423
- piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
1424
1605
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
1425
1606
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
1426
1607
  # use diagonal matrix as preconditioner for this dim
1427
1608
  Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
1428
-
1429
- piece1A.append(letters[i])
1430
- piece2A = piece2A + letters[i]
1431
- piece3A = piece3A + letters[i]
1432
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1433
- subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1434
- exprGs.append(subscripts)
1435
- piece1P.append(letters[i + 13])
1436
- piece2P.append(letters[i + 13])
1437
- piece3P = piece3P + letters[i + 13]
1438
- piece4P = piece4P + letters[i + 13]
1439
1609
  else:
1440
1610
  # use triangular matrix as preconditioner for this dim
1441
1611
  Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
1442
- piece1A.append(letters[i] + letters[i + 13])
1443
- piece2A = piece2A + letters[i + 13]
1444
- piece3A = piece3A + letters[i]
1445
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1446
- piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1447
- subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
1448
- exprGs.append(subscripts)
1449
- a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1450
- piece1P.append(a + b)
1451
- piece2P.append(a + c)
1452
- piece3P = piece3P + c
1453
- piece4P = piece4P + b
1454
-
1455
- exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
1456
- exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
1457
- return [Q, (exprA, tuple(exprGs), exprP)]
1612
+ return Q
1458
1613
 
1459
1614
 
1460
- @decorator
1461
- def psgd_balance_Q(Q_in):
1462
- norms = torch.stack([q.norm(float("inf")) for q in Q_in])
1463
- geometric_mean = norms.log().mean().exp()
1464
- norms = geometric_mean / norms
1465
- torch._foreach_mul_(Q_in, list(norms))
1615
+ @decorator_knowngood
1616
+ def psgd_balance_Q(Q):
1617
+ norms = [promote(q.norm(float("inf"))).log() for q in Q]
1618
+ geometric_mean = sum([n for n in norms]) / len(Q)
1619
+ for q, n in zip(Q, norms):
1620
+ q *= (geometric_mean - n).exp()
1466
1621
 
1467
1622
 
1468
- @decorator
1469
- def psgd_balance_lra(U: Tensor, V: Tensor):
1470
- u_norm = promote(torch.linalg.vector_norm(U))
1471
- v_norm = promote(torch.linalg.vector_norm(V))
1472
- scale = (u_norm / v_norm) ** 0.5
1473
- U.div_(scale)
1474
- V.mul_(scale)
1623
+ @decorator_knowngood
1624
+ def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
1625
+ u_norm = sum(u.square().sum().double() for u in U)
1626
+ v_norm = sum(v.square().sum().double() for v in V)
1627
+ scale = (u_norm / v_norm) ** 0.25 # u_norm/v_norm are squared L2 norms; fourth root, as the rescaling is split across the two factors
1628
+ scale = torch.where(torch.logical_and(torch.isfinite(scale), scale > 1e-6), scale, 1)
1629
+ stochastic_multiply_(U, [1 / scale] * len(U))
1630
+ stochastic_multiply_(V, [scale] * len(V))
1631
+ return multi_flatten((U, 1), (V, 1), (d, 0))
1475
1632
 
1476
1633
 
1477
1634
  @decorator
1478
1635
  def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1479
1636
  dtype = min_dtype([U, V, x])
1480
- return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1637
+ return x + compiled_einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1638
+
1639
+
1640
+ @decorator_knowngood
1641
+ def _compilable_d_step(
1642
+ d: Tensor,
1643
+ d_orig: List[Tensor],
1644
+ invQtv: Tensor,
1645
+ vector: Tensor,
1646
+ inverse_precond_vector: Tensor,
1647
+ hessian_vector: Tensor,
1648
+ precond_hessian_vector: Tensor,
1649
+ eps: Tensor,
1650
+ step: Tensor,
1651
+ delayed: bool,
1652
+ ):
1653
+ precond_hessian_vector = promote(precond_hessian_vector)
1654
+ hessian_vector = promote(hessian_vector)
1655
+ vector = promote(vector)
1656
+ inverse_precond_vector = promote(inverse_precond_vector)
1657
+ invQtv = promote(invQtv)
1658
+ inverse_precond_vector = invQtv - inverse_precond_vector
1659
+
1660
+ nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector
1661
+
1662
+ """
1663
+ 1) Sketching
1664
+ 1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
1665
+ 1.2) reduced-precision selection of largest element (halves memory traffic)
1666
+ 2) Computation
1667
+ 2.1) select relevant indices
1668
+ 2.2) redo 1.1 in double precision for scalar values
1669
+ 2.3) return high-precision normalized step-size
1670
+ overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
1671
+ improving precision
1672
+ """
1673
+ a0 = promote(d) * precond_hessian_vector
1674
+ a1 = vector
1675
+ b0 = inverse_precond_vector / promote(d)
1676
+ b1 = hessian_vector
1677
+
1678
+ divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
1679
+ idx = divisor.bfloat16().flatten().argmax()
1680
+ a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
1681
+ b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
1682
+ divisor = (a * b).sqrt().clamp(min=eps)
1683
+ step = -step / divisor
1684
+
1685
+ # fused update(s)
1686
+ apply_flat_add(d_orig, nablaD, step)
1687
+ if not delayed:
1688
+ copy_stochastic_(d, promote(d) - nablaD * step)
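# Standalone sketch of the "sketch then recompute" pattern documented above: the argmax of the
# elementwise scores is taken in bfloat16 (cheap, low memory traffic), then only the selected
# entry is recomputed in float64 to form the high-precision scalar used for the step size.
import torch

a = torch.randn(1_000_000)
b = torch.randn(1_000_000)
coarse = (a.square() * b.square()).bfloat16()      # reduced-precision score sketch
idx = coarse.argmax().unsqueeze(0)                 # only the index has to be right
exact = a.index_select(0, idx).double().square() * b.index_select(0, idx).double().square()
divisor = exact.sqrt().clamp(min=1e-12)            # double-precision normalizer, as above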
1481
1689
 
1482
1690
 
1483
1691
  def update_lra_precond_(
@@ -1489,13 +1697,14 @@ def update_lra_precond_(
1489
1697
  eps: float,
1490
1698
  step: float,
1491
1699
  delayed: bool,
1700
+ precond_u: bool,
1492
1701
  ):
1493
1702
  """
1494
1703
  Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1495
1704
  """
1496
1705
  U_orig, V_orig, d_orig = U, V, d
1497
1706
 
1498
- U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
1707
+ U, V, d = _lra_flatten_and_balance(U, V, d)
1499
1708
 
1500
1709
  dtype = min_dtype([U, V, vector, hessian_vector])
1501
1710
  U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
@@ -1503,10 +1712,10 @@ def update_lra_precond_(
1503
1712
  eps = scalar_guard(eps, vector)
1504
1713
 
1505
1714
  Qh = low_rank_mm(U, V, d * hessian_vector)
1506
- Ph = d * low_rank_mm(V, U, Qh)
1715
+ Ph = low_rank_mm(V, U, Qh)
1507
1716
  rank = U.size(1)
1508
1717
 
1509
- VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
1718
+ VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
1510
1719
  I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
1511
1720
  IpVtU = I + VtU
1512
1721
  invQtv = vector / d
@@ -1524,47 +1733,39 @@ def update_lra_precond_(
1524
1733
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1525
1734
 
1526
1735
  invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
1527
- invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1528
- invPv = invPv / d
1736
+ invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1529
1737
 
1530
- nablaD = Ph * hessian_vector - vector * invPv
1531
- divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
1532
- divisor = divisor.add(eps).sqrt().max()
1533
- d_step = step / divisor
1534
-
1535
- apply_flat_add(d_orig, d * nablaD, -d_step)
1738
+ eps, step = scalar_guard(eps, step, vector)
1739
+ _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
1536
1740
 
1537
1741
  a, b = Qh, invQtv
1538
1742
 
1539
- precond_u = random.random() < 0.5 # update either U or V, not both at the same time
1540
1743
  precond = V if precond_u else U
1541
- atV = torch.einsum("b,br->r", a, precond) # o == one
1542
- btV = torch.einsum("b,br->r", b, precond)
1543
- atVVt = torch.einsum("r,br->b", atV, precond)
1544
- btVVt = torch.einsum("r,br->b", btV, precond)
1545
- precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
1744
+ atV = compiled_einsum("b,br->r", a, precond) # o == one
1745
+ btV = compiled_einsum("b,br->r", b, precond)
1746
+ atVVt = compiled_einsum("r,br->b", atV, precond)
1747
+ btVVt = compiled_einsum("r,br->b", btV, precond)
1748
+ precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
1546
1749
  if precond_u:
1547
- a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
1548
- b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
1750
+ a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
1751
+ b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
1549
1752
  else:
1550
- a = a + torch.einsum("br,r->b", V, atV)
1551
- b = b + torch.einsum("br,r->b", V, btV)
1552
- a = torch.einsum("b,r->br", a, atV)
1553
- b = torch.einsum("b,r->br", b, btV)
1753
+ a = a + compiled_einsum("br,r->b", V, atV)
1754
+ b = b + compiled_einsum("br,r->b", V, btV)
1755
+ a = compiled_einsum("b,r->br", a, atV)
1756
+ b = compiled_einsum("b,r->br", b, btV)
1554
1757
  apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
1555
-
1556
1758
  if not delayed:
1557
- stochastic_add_([d], [d * nablaD], -d_step)
1558
1759
  stochastic_add_([U if precond_u else V], [b - a], precond_step)
1559
1760
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1560
1761
 
1561
1762
 
1562
- def lra_precond(U, V, d, g):
1763
+ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
1563
1764
  """
1564
1765
  As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
1565
1766
  """
1566
- g = low_rank_mm(U, V, d * g)
1567
- return d * low_rank_mm(V, U, g)
1767
+ new_g = low_rank_mm(U, V, d * g)
1768
+ return d * low_rank_mm(V, U, new_g)
1568
1769
 
1569
1770
 
1570
1771
  @decorator_knowngood
@@ -1575,16 +1776,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
1575
1776
 
1576
1777
 
1577
1778
  @decorator_knowngood
1578
- def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
1579
- update = lra_precond(U, V, d, update)
1779
+ def _compilable_lra_update_(
1780
+ params: List[Tensor],
1781
+ update: List[Tensor],
1782
+ U: Tensor,
1783
+ V: Tensor,
1784
+ d: Tensor,
1785
+ lr: Tensor,
1786
+ decay: Tensor,
1787
+ caution: bool,
1788
+ grads: List[Tensor],
1789
+ ):
1790
+ update = lra_precond(U, V, d, flatten(update))
1580
1791
  start = 0
1581
1792
  update = update.flatten()
1582
- for p in params:
1793
+ for p, g in zip(params, grads):
1583
1794
  size = p.numel()
1584
- copy_stochastic_(p, update[start : start + size].view_as(p))
1795
+ update_param_(p, update[start : start + size].view_as(p), lr, decay, caution, g)
1585
1796
  start += size
1586
1797
 
1587
1798
 
1799
+ def apply_lra_update(
1800
+ params: List[Tensor],
1801
+ update: Tensor,
1802
+ U: Tensor,
1803
+ V: Tensor,
1804
+ d: Tensor,
1805
+ lr: float,
1806
+ decay: float,
1807
+ caution: bool,
1808
+ grads: List[Tensor],
1809
+ ):
1810
+ params, grads = list_guard(params, grads)
1811
+ lr, decay = scalar_guard(lr, decay, params[0])
1812
+ _compilable_lra_update_(params, update, U, V, d, lr, decay, caution, grads)
1813
+
1814
+
1588
1815
  @decorator_knowngood
1589
1816
  def apply_flat_update(params: List[Tensor], update: Tensor):
1590
1817
  start = 0
@@ -1595,6 +1822,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
1595
1822
  start += size
1596
1823
 
1597
1824
 
1825
+ @decorator_knowngood
1826
+ def zero_(x: List[Tensor]):
1827
+ for i in x:
1828
+ i.zero_()
1829
+
1830
+
1598
1831
  @decorator_knowngood
1599
1832
  def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1600
1833
  start = 0
@@ -1620,7 +1853,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1620
1853
  @decorator_knowngood
1621
1854
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1622
1855
  last_dim = x[0].shape[-remaining:] if remaining else []
1623
- return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
1856
+ return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
1857
+
1858
+
1859
+ @decorator_knowngood
1860
+ def multi_flatten(*xs: Tuple[List[Tensor], int]):
1861
+ return [flatten(x, i) for x, i in xs]
1624
1862
 
1625
1863
 
1626
1864
  @decorator_knowngood
@@ -1634,68 +1872,277 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1634
1872
  return flatten(vs), flatten(gs)
1635
1873
 
1636
1874
 
1637
- @decorator_knowngood
1638
1875
  def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1639
1876
  md = min_dtype(args)
1640
- return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1877
+ return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1641
1878
 
1642
1879
 
1643
- def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1644
- order = G.dim()
1645
- if order > 1:
1646
- conjB = conjB.view_as(G).permute(*range(1, order), 0)
1647
- conjB = conjB.to(promote(G.dtype))
1648
- A = casted_einsum(exprA, *Q, G)
1649
- for i, q in enumerate(Q):
1880
+ @decorator_knowngood
1881
+ def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
1882
+ triangular_qs = []
1883
+ conjB = promote(conjB)
1884
+ for i, q in enumerate(Qs):
1650
1885
  q = promote(q)
1651
1886
  if q.dim() <= 1:
1652
- conjB /= q
1887
+ if conjB.ndim == 0:
1888
+ conjB = conjB / q
1889
+ else:
1890
+ shape = [1] * conjB.ndim
1891
+ shape[i] = -1
1892
+ conjB = conjB / q.view(shape)
1653
1893
  else:
1654
- solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
1655
- conjB = solved.reshape_as(conjB)
1656
- if i < order - 1:
1657
- conjB = conjB.transpose(i, -1)
1894
+ triangular_qs.append((i, q))
1895
+ return triangular_qs, conjB
1896
+
1897
+
1898
+ @decorator_knowngood
1899
+ def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
1900
+ solved = solved.reshape(transposed_shape)
1901
+ solved = solved.transpose(-1, last_dim)
1902
+ solved = solved.reshape(original_shape)
1903
+ solved = solved.transpose(-1, new_dim)
1904
+ return solved.contiguous(), solved.shape
1905
+
1906
+
1907
+ def ndim_tuple(Q: list[Tensor]) -> tuple:
1908
+ return tuple(q.ndim for q in Q)
1909
+
1910
+
1911
+ def psgd_calc_A_and_conjB(G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1912
+ exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
1913
+ A = casted_einsum(exprA, *Q, G)
1914
+ solve = torch.compiler.disable(torch.linalg.solve_triangular)
1915
+ transposed_shape = original_shape = conjB.shape
1916
+ prev_i = -1
1917
+ qs, conjB = _psgd_calc_scalars_(Q, conjB)
1918
+ for i, tri_q in qs:
1919
+ conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
1920
+ prev_i = i
1921
+ conjB = solve(tri_q, conjB, upper=True, left=False)
1922
+ conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
1658
1923
  return A, conjB
1659
1924
 
1660
1925
 
1661
- def psgd_lb(A, max_abs):
1662
- A /= max_abs
1663
- a0 = torch.einsum("ij,ij->j", A, A)
1664
- i = torch.argmax(a0)
1665
- x = torch.index_select(A, 1, i).flatten().contiguous()
1666
- x = torch.einsum("i,ij->j", x, A)
1667
- x /= x.norm()
1668
- x = torch.einsum("j,kj->k", x, A)
1669
- x = x.norm()
1670
- x *= max_abs
1671
- return x
1926
+ @decorator_knowngood
1927
+ def _random_projection(x: Tensor, scale: Optional[Tensor]):
1928
+ if scale is None:
1929
+ scale = x.norm(float("inf")).clamp(min=1e-8)
1930
+ k = 2 ** math.ceil(math.log2(math.log2(min(x.shape)))) # next-largest-power-of-2 of log2-of-size
1931
+ norm = x.square().sum(0)
1932
+ indices = torch.topk(norm, k, largest=True).indices
1933
+ return x.index_select(1, indices).contiguous() / scale, scale
1934
+
1935
+
1936
+ def max_singular_value_exact(A, use_lobpcg: bool = False):
1937
+ try:
1938
+ if use_lobpcg:
1939
+ A = A @ A.T
1940
+ eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
1941
+ return eigval[0].sqrt()
1942
+ else:
1943
+ return torch.linalg.svd(A, driver="gesvdj")[1].max() # == linalg.matrix_norm(A, ord=2)
1944
+ except torch.linalg.LinAlgError:
1945
+ return torch.zeros((), device=A.device, dtype=A.dtype)
1946
+
1947
+
1948
+ @decorator_knowngood
1949
+ def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
1950
+ """
1951
+ Rayleigh quotient of row with the largest norm + optional power iterations
1952
+ """
1953
+ x_norm, max_idx = A.norm(dim=1).max(dim=0)
1954
+ x = A.index_select(0, max_idx).flatten().contiguous()
1955
+ A = A / x_norm
1956
+ x = x / x_norm
1957
+ for _ in range(iterations):
1958
+ x = A.T.mv(A.mv(x)) # A.T @ A @ x, but explicitly telling torch.compile not to compute the full matrix
1959
+ x = x / x.norm()
1960
+ return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm
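# Quick accuracy check for the power-iteration estimate above (hypothetical size): seeding with
# the largest-norm row plus a few iterations usually lands within a few percent of the exact
# spectral norm, at a fraction of the SVD cost.
import torch

from heavyball.utils import max_singular_value_power_iter

A = torch.randn(512, 256)
exact = torch.linalg.matrix_norm(A, ord=2)
approx = max_singular_value_power_iter(A)  # helper defined above, 5 iterations by default
print(exact, approx, (approx - exact).abs() / exact)  # relative error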
1961
+
1962
+
1963
+ @decorator_knowngood
1964
+ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
1965
+ """
1966
+ Adapted from @evanatyourservice
1967
+ """
1968
+ Y, max_abs = _random_projection(A, max_abs)
1969
+ Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
1970
+ Q = Q / max_abs
1971
+ Z = A.T @ Q
1972
+ W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
1973
+ sketch_norm = max_singular_value_exact(Z.T @ W)
1974
+ return sketch_norm * max_abs
1975
+
1976
+
1977
+ @decorator_knowngood
1978
+ def max_singular_value(
1979
+ A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
1980
+ ) -> Tensor:
1981
+ if min(A.shape) <= max_svd:
1982
+ return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
1983
+ if use_cholesky or power_iter < 0:
1984
+ return max_singular_value_cholesky(A, max_abs)
1985
+ return max_singular_value_power_iter(A, None, iterations=power_iter)
1986
+
1987
+
1988
+ @decorator_knowngood
1989
+ def _psgd_default_preconditioner_grad(
1990
+ terms: List[Tuple[Tensor, Tensor]],
1991
+ Q: List[Tensor],
1992
+ ) -> List[Tensor]:
1993
+ out = []
1994
+ for q, (x, y) in zip(Q, terms):
1995
+ x = promote(x)
1996
+ y = promote(y)
1997
+ update = x - y
1998
+ if q.ndim < 2:
1999
+ update = q * update
2000
+ else:
2001
+ update = (q @ update).triu()
2002
+ out.append(update)
2003
+ return out
2004
+
2005
+
2006
+ @decorator_knowngood
2007
+ def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
2008
+ if isinstance(Q[0], tuple):
2009
+ psgd_balance_Q([o[1] for o in Q])
2010
+ return line_to_triu(Q, symmetric_output)
2011
+ psgd_balance_Q(Q)
2012
+ return Q
2013
+
2014
+
2015
+ @functools.lru_cache(maxsize=None)
2016
+ def calcG_expr(q_dim, g_dim):
2017
+ exprs = []
2018
+ base = einsum_base[:g_dim]
2019
+ for i, q in enumerate(q_dim):
2020
+ new = list(base)
2021
+ if q == 2:
2022
+ new[i] = "Z"
2023
+ out = f"{base[i]}Z"
2024
+ else:
2025
+ out = base[i]
2026
+ exprs.append(f"{base},{''.join(new)}->{out}")
2027
+ return exprs
1672
2028
 
1673
2029
 
1674
2030
  @decorator
1675
- def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
2031
+ def psgd_update_precond(
2032
+ G: Tensor,
2033
+ precond_lr: float,
2034
+ oq: "TriuOrLine",
2035
+ store_triu_as_line: bool,
2036
+ velocity: Optional[List[Tensor]],
2037
+ beta2: float,
2038
+ ortho_method: Optional[str],
2039
+ V: Tensor,
2040
+ running_lower_bound: List[Tensor],
2041
+ lower_bount_beta: float,
2042
+ power_iter: int,
2043
+ ) -> None:
1676
2044
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1677
- exprA, exprGs, _ = exprs
1678
- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1679
-
1680
- for q, exprG, o in zip(Q, exprGs, oq):
1681
- term1 = promote(torch.einsum(exprG, A, A))
1682
- term2 = promote(torch.einsum(exprG, conjB, conjB))
1683
- term1, term2 = term1 - term2, term1 + term2
1684
- term1 *= precond_lr
1685
- norm = term2.norm(float("inf"))
1686
- if q.dim() < 2:
1687
- term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
2045
+ Q = _balance_to_triu(oq)
2046
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2047
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2048
+
2049
+ A, conjB = psgd_calc_A_and_conjB(G, Q, V)
2050
+ terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
2051
+ del A, conjB, V
2052
+ updates = _psgd_default_preconditioner_grad(terms, Q)
2053
+ _psgd_precond_update_(
2054
+ updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2055
+ )
2056
+ return None
2057
+
2058
+
2059
+ @decorator_knowngood
2060
+ def _psgd_precond_update_(
2061
+ matmuled: List[Optional[Tensor]],
2062
+ Q: "TriuOrLine",
2063
+ running_lower_bound: List[Tensor],
2064
+ lower_bount_beta: Tensor,
2065
+ precond_lr: Tensor,
2066
+ store_triu_as_line: bool,
2067
+ power_iter: int,
2068
+ ):
2069
+ for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
2070
+ if isinstance(oq, tuple):
2071
+ oq = oq[1]
2072
+
2073
+ q = promote(oq)
2074
+ if update.ndim < 2:
2075
+ lb = update.norm(float("inf"))
1688
2076
  else:
1689
- torch.triu(term1, out=term1)
1690
- term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
1691
- term1 = torch.mm(term1, q.to(term1.dtype))
1692
- if store_triu_as_line:
1693
- term1 = triu_to_line([term1])[0][1] # Convert update to line format
1694
- # Apply update directly to the tensor part of the state tuple o[1]
1695
- stochastic_add_(o[1], term1, -1)
2077
+ lb = max_singular_value(update, None, power_iter=power_iter)
2078
+ update = promote(update)
2079
+ if store_triu_as_line:
2080
+ update = triu_to_line([update])[0][1]
2081
+
2082
+ lb = promote(lb)
2083
+ lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
2084
+ copy_stochastic_(lb_state, lb)
2085
+ copy_stochastic_(oq, q - update / lb * precond_lr)
2086
+
2087
+
2088
+ @decorator_knowngood
2089
+ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
2090
+ """
2091
+ I: Identity
2092
+ U: Update / gg / target
2093
+ Q: q, preconditioner
2094
+ scale: scalar scale
2095
+ ---
2096
+ U = T * scale - I
2097
+ F = I - U # = 2I - T * scale
2098
+ O = F @ Q @ F - Q
2099
+ """
2100
+ out = []
2101
+ for gg, q in zip(GG, Q):
2102
+ if gg.ndim < 2:
2103
+ scale = max(1, gg.numel()) / numel
2104
+ target = promote(gg)
2105
+ update = target * scale - 1
2106
+ out.append(q - (1 - update) * q * (1 - update))
1696
2107
  else:
1697
- # Apply update to the state tensor o
1698
- stochastic_add_(o, term1, -1)
2108
+ scale = gg.size(0) / numel
2109
+ gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
2110
+ update = q - gg @ q @ gg
2111
+ out.append(update + update.T) # make matrix symmetric
2112
+ return out
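# Small sanity check of the quadratic-form gradient above (sizes are hypothetical): when the
# accumulated statistics already satisfy gg * scale == I the update vanishes, and for generic
# PSD statistics with a symmetric q the returned matrix stays symmetric (update + update.T).
import torch

from heavyball.utils import _psgd_quad_preconditioner_grad

n, numel = 8, 64
q = torch.eye(n) + 0.01 * torch.randn(n, n)
q = (q + q.T) / 2                              # symmetric preconditioner factor
gg_fixed = torch.eye(n) * (numel / n)          # gg * scale == I  ->  zero update
(zero,) = _psgd_quad_preconditioner_grad([gg_fixed], [q], numel)
assert torch.allclose(zero, torch.zeros_like(zero), atol=1e-6)

gg = torch.randn(n, n)
gg = gg @ gg.T                                 # generic PSD statistics
(upd,) = _psgd_quad_preconditioner_grad([gg], [q], numel)
assert torch.allclose(upd, upd.T, atol=1e-5)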
2113
+
2114
+
2115
+ @decorator
2116
+ def inverse_free_psgd_update_precond(
2117
+ G: Tensor,
2118
+ precond_lr: float,
2119
+ oq: List[Tensor],
2120
+ store_triu_as_line: bool,
2121
+ velocity: Optional[List[Tensor]],
2122
+ beta2: float,
2123
+ ortho_method: Optional[str],
2124
+ V: None,
2125
+ running_lower_bound: List[Tensor],
2126
+ lower_bount_beta: float,
2127
+ power_iter: int,
2128
+ ) -> Tensor:
2129
+ """Update Kronecker product preconditioner Q with pair (V, G)."""
2130
+ assert V is None
2131
+ assert ortho_method is None
2132
+ assert velocity is None
2133
+ del V, ortho_method, velocity
2134
+
2135
+ Q = _balance_to_triu(oq, True)
2136
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2137
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2138
+
2139
+ G = psgd_precond_grad(G, Q)
2140
+ terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
2141
+ matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
2142
+ _psgd_precond_update_(
2143
+ matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2144
+ )
2145
+ return G
1699
2146
 
1700
2147
 
1701
2148
  @decorator_knowngood
@@ -1732,6 +2179,34 @@ def rmsnorm_clip_(x, clip_at: float = 1.0):
1732
2179
  return _compilable_rmsnorm_clip_(x, clip_at)
1733
2180
 
1734
2181
 
2182
+ @decorator_knowngood
2183
+ def _compilable_global_rmsnorm_clip_(x, clip_at):
2184
+ x = list(map(promote, x))
2185
+ norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
2186
+ norm = norm**0.5
2187
+ norm = norm.clamp(min=clip_at)
2188
+ return torch._foreach_div(x, norm)
2189
+
2190
+
2191
+ @decorator_knowngood
2192
+ def _compilable_global_l2norm_clip_(x, clip_at):
2193
+ x = list(map(promote, x))
2194
+ norm = sum([x.square().sum() for x in x])
2195
+ norm = norm**0.5
2196
+ norm = norm.clamp(min=clip_at)
2197
+ return torch._foreach_div(x, norm)
2198
+
2199
+
2200
+ def global_rmsnorm_clip(x, clip_at: float = 1.0):
2201
+ x = list_guard(x)
2202
+ return _compilable_global_rmsnorm_clip_(x, clip_at)
2203
+
2204
+
2205
+ def global_l2norm_clip(x, clip_at: float = 1.0):
2206
+ x = list_guard(x)
2207
+ return _compilable_global_l2norm_clip_(x, clip_at)
2208
+
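# Minimal sketch of the difference between the two global clips above: both divide every tensor
# by max(norm, clip_at), but the RMS variant uses sqrt(sum(x^2) / total_numel) while the L2
# variant uses sqrt(sum(x^2)), so the L2 form rescales much more aggressively for large models.
import torch

from heavyball.utils import global_l2norm_clip, global_rmsnorm_clip

grads = [torch.randn(1000), torch.randn(10)]
rms_clipped = global_rmsnorm_clip(grads, clip_at=1.0)  # global RMS ~1 -> little to no clipping
l2_clipped = global_l2norm_clip(grads, clip_at=1.0)    # global L2 ~sqrt(1010) -> strong rescale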
2209
+
1735
2210
  def rmsnorm_normalize_(x, clip_at: float = 1e-6):
1736
2211
  x = list_guard(x)
1737
2212
  return _compilable_rmsnorm_clip_(x, clip_at)
@@ -1809,6 +2284,17 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1809
2284
  _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
1810
2285
 
1811
2286
 
2287
+ @decorator_knowngood
2288
+ def _compilable_weight_decay_to_init_(p, init, weight_decay):
2289
+ _lerp(p, promote(init), 1 - weight_decay)
2290
+
2291
+
2292
+ def weight_decay_to_init_(p, init, weight_decay):
2293
+ p, init = list_guard(p, init)
2294
+ weight_decay = scalar_guard(weight_decay, p[0])
2295
+ _compilable_weight_decay_to_init_(p, init, weight_decay)
2296
+
2297
+
1812
2298
  @decorator_knowngood
1813
2299
  def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1814
2300
  ema32 = _lerp(ema, p, ema_decay)
@@ -1867,35 +2353,25 @@ def triu_to_line(Q_list: List[Tensor]):
          if q.dim() < 2:
              out.append((None, q))
          else:
-             out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
+             out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
      return out
 
 
- def _triu_shape(numel):
-     n = int((2 * numel) ** 0.5)
-     assert n * (n + 1) == 2 * numel
-     return n, n
-
-
- @decorator
- def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
+ @decorator_knowngood
+ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
      new = []
      for shape, q in Q_list:
          if shape is not None:
-             shape = _triu_shape(q.numel())
-             x = torch.zeros(shape, device=q.device, dtype=q.dtype)
-             x[tuple(torch.triu_indices(*shape, device=q.device))] = q
-             q = x
+             x, y = torch.triu_indices(*shape, device=q.device)
+             q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
+             q_mat[x, y] = q
+             if symmetric_output:
+                 q_mat[y, x] = q
+             q = q_mat
          new.append(q)
      return new
 
 
- def update_triu_(q_state, materialised):
-     for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
-         assert shape0 == shape1
-         copy_stochastic_(q, m)
-
-
  _warned = set()
 
 
@@ -1918,52 +2394,118 @@ def psgd_should_update(
      return int(group[name]) > int(cumulative_prob)
 
 
+ @functools.lru_cache(maxsize=None)
+ def cached_precond_grad_expr(Q_dim, grad_dim):
+     expr = [f"{c.upper()}{c}" if q_ == 2 else c for c, q_ in zip(einsum_base, Q_dim)]
+     expr = ",".join(expr)
+     grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
+     out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
+     return f"{expr},{grad_expr}->{out_expr}"
+
+
  @decorator_knowngood
  def precond_grad_cached_(
-     expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
+     ea: Tensor,
+     cached_q: List[Tensor],
+     caution: bool = False,
+     grad: Optional[Tensor] = None,
+     cast: bool = True,
  ):
      if caution:
          ea = _compilable_cautioning(grad, ea)
      md = min_dtype(list(cached_q) + [ea])
      args = [q.to(md) for q in cached_q]
      args = args + [ea.to(md)]
-     new = torch.einsum(expr, *args)
+     expr = cached_precond_grad_expr(ndim_tuple(cached_q), grad.ndim)
+     new = compiled_einsum(expr, *args)
      if cast:
          return new.to(ea.dtype)
      return new
 
 
+ TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
+
+
  @decorator_knowngood
- def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-     precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
+     precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
      update_param_(param, precond, lr, decay, caution=False)
 
 
- def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+ def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
      lr = scalar_guard(lr, param[0])
-     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
+     _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
+
+
+ @functools.lru_cache(maxsize=None)
+ def precond_grad_expr(Q_dim, grad_dim):
+     expr = [
+         f"{c2}{c.upper()},{c2}{c}" if q_ == 2 else f"{c},{c}" for c, c2, q_ in zip(einsum_base, einsum_base[13:], Q_dim)
+     ]
+     expr = ",".join(expr)
+     grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
+     out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
+     return f"{expr},{grad_expr}->{out_expr}"
 
 
  @decorator_knowngood
- def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+ def psgd_precond_grad(
+     ea: Tensor,
+     preconds: TriuOrLine,
+     caution: bool = False,
+     grad: Optional[Tensor] = None,
+     store_triu_as_line: bool = False,
+     symmetric_output: bool = False,
+ ):
      if caution:
          ea = _compilable_cautioning(grad, ea)
+     if store_triu_as_line:
+         preconds = line_to_triu(preconds, symmetric_output)
      md = min_dtype(list(preconds) + [ea])
      args = [q.to(md) for q in preconds]
-     args = args + args + [ea.to(md)]
-     new = torch.einsum(expr, *args)
+     expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
+     new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
      return new.to(ea.dtype)
 
 
  @decorator_knowngood
- def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-     precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+ def _compilable_fused_psgd_precond_grad(
+     ea: Tensor,
+     param,
+     lr,
+     grad,
+     decay,
+     caution,
+     preconds: TriuOrLine,
+     store_triu_as_line: bool = False,
+     symmetric_output: bool = False,
+ ):
+     precond = psgd_precond_grad(
+         ea,
+         preconds,
+         caution=caution,
+         grad=grad,
+         store_triu_as_line=store_triu_as_line,
+         symmetric_output=symmetric_output,
+     )
      update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
- def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+ def fused_psgd_precond_grad(
+     ea: Tensor,
+     param,
+     lr,
+     grad,
+     decay,
+     caution,
+     preconds: TriuOrLine,
+     store_triu_as_line: bool = False,
+     symmetric_output: bool = False,
+ ):
      lr = scalar_guard(lr, param[0])
-     _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)
+     _compilable_fused_psgd_precond_grad(
+         ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
+     )
 
 
  @decorator_knowngood
@@ -2015,7 +2557,15 @@ def caution(g, update):
      return _compilable_cautioning(g, update)
 
 
- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
+ def _inner_precond_update_prob_schedule(
+     n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
+ ):
+     return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
+
+
+ def precond_update_prob_schedule(
+     max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
+ ):
      """Anneal preconditioner update probability during beginning of training.
 
      PSGD benefits from more preconditioner updates at the beginning of training,
@@ -2026,11 +2576,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
      `min_prob` by ~4000 steps. Default settings work very well for most models and
      training regimes.
      """
-
-     def _schedule(n):
-         return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
-
-     return _schedule
+     return functools.partial(
+         _inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
+     )
 
 
  def merge_group(group, *tensors):
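A short sketch of how the reworked schedule behaves (assuming `precond_update_prob_schedule` is imported from heavyball.utils): the returned `functools.partial` is picklable, unlike the previous closure, and for step `n` it evaluates `max(min_prob, max_prob * decay ** max(n - flat_start, 0))`:

    from heavyball.utils import precond_update_prob_schedule

    schedule = precond_update_prob_schedule()  # max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000
    probs = [schedule(n) for n in (0, 1000, 2000, 5000)]
    # stays at 1.0 through the 1000-step flat start, then decays geometrically,
    # bottoming out at min_prob=0.03 around step ~4500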
@@ -2164,3 +2712,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
  def disable_caution_scaling():
      global _compilable_cautioning
      _compilable_cautioning = _compilable_caution_no_scale
+
+
+ @decorator_knowngood
+ def sam_step(parameters, ball_size, adaptive: bool = True):
+     old_params = []
+     for p in parameters:
+         old_params.append(p.detach().clone())
+         grad = promote(p.grad)
+         if adaptive:
+             grad = grad * promote(p).square()
+         stochastic_add_(p.data, grad, ball_size)
+         p.grad.zero_()
+     return old_params