heavyball 1.7.2__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,28 +1,28 @@
1
+ import collections
1
2
  import contextlib
2
3
  import functools
3
4
  import gc
4
5
  import inspect
5
6
  import math
7
+ import pickle
6
8
  import random
7
9
  import re
8
10
  import string
9
11
  import warnings
10
- from typing import Callable, List, Optional, Tuple, Union
12
+ from typing import Callable, List, Literal, Optional, Tuple, Union
11
13
 
12
14
  import numpy as np
13
15
  import torch
14
16
  from torch import Tensor
15
- from torch._dynamo import config
16
17
  from torch._dynamo.exc import TorchDynamoException
17
18
  from torch.backends import cudnn, opt_einsum
18
19
  from torch.utils._pytree import tree_map
19
20
 
20
- config.cache_size_limit = 2**16
21
-
22
21
  compile_mode = "max-autotune-no-cudagraphs"
23
22
  dynamic = False
24
23
  compile_mode_recommended_to_none = None
25
- zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
24
+ zeroth_power_mode = "newtonschulz"
25
+ precise_zeroth_power_mode = "qr" # or svd
26
26
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
27
  _cudnn_double_backward_pattern = re.compile(
28
28
  r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -68,6 +68,16 @@ def decorator_knowngood(func: Callable, fullgraph: bool = True):
68
68
  einsum_base = string.ascii_lowercase
69
69
 
70
70
 
71
+ @decorator_knowngood
72
+ def compiled_einsum(expr, *args):
73
+ """
74
+ This is necessary to avoid the slowdown introduced by uncompiled einsum:
75
+ uncompiled einsum is roughly twice as slow once three size-1 dimensions are added.
76
+ For more, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
77
+ """
78
+ return torch.einsum(expr, *args)
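A minimal stand-alone sketch (not part of the package) of the pattern compiled_einsum relies on: compile a thin einsum wrapper once and reuse it, so the eager-dispatch overhead described in the docstring is paid only at compile time. The wrapper name below is hypothetical; only torch is assumed.

import torch

@torch.compile(fullgraph=True, dynamic=False)
def einsum_sketch(expr, *args):
    # identical contract to torch.einsum; the compiled call site skips eager dispatch overhead
    return torch.einsum(expr, *args)

a = torch.randn(8, 1, 1, 1, 16)  # three size-1 dimensions: the slow case from the docstring
b = torch.randn(16, 4)
out = einsum_sketch("abcde,ef->abcdf", a, b)
assert out.shape == (8, 1, 1, 1, 4)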
79
+
80
+
71
81
  @decorator_knowngood
72
82
  def _compilable_schedule_free_(
73
83
  p: List[Tensor],
@@ -122,6 +132,47 @@ def schedule_free_(
122
132
  return weight_sum
123
133
 
124
134
 
135
+ @decorator_knowngood
136
+ def _compilable_msam(
137
+ lr: Tensor,
138
+ beta1: Tensor,
139
+ param: List[Tensor],
140
+ z: List[Tensor],
141
+ update: List[Tensor],
142
+ grad: List[Tensor],
143
+ exp_avg: List[Tensor],
144
+ caution: bool,
145
+ decay: Tensor,
146
+ sam_step_size: Tensor,
147
+ ):
148
+ exp_avg32 = _lerp(exp_avg, update, beta1)
149
+ for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
150
+ u_ = u_.view_as(z_)
151
+ z32_ = promote(z_)
152
+ if caution:
153
+ u_ = _compilable_cautioning(promote(g_), u_)
154
+ z32_ = z32_ * (1 - decay * lr) + u_ * -lr
155
+ copy_stochastic_(z_, z32_)
156
+ copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
157
+
158
+
159
+ def msam_(
160
+ lr: float,
161
+ beta1: float,
162
+ param: List[Tensor],
163
+ z: List[Tensor],
164
+ update: List[Tensor],
165
+ grad: List[Tensor],
166
+ exp_avg: List[Tensor],
167
+ caution: bool,
168
+ weight_decay: float,
169
+ sam_step_size: float,
170
+ ):
171
+ param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
172
+ lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
173
+ _compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
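For orientation, a plain-torch sketch of the MSAM update that _compilable_msam performs per tensor, assuming _lerp(buf, new, beta) computes the EMA beta * buf + (1 - beta) * new; the caution mask and stochastic-rounding copies are omitted, and all names below are illustrative.

import torch

def msam_single_step(p, z, exp_avg, update, lr, beta1, weight_decay, sam_step_size):
    # momentum buffer: exp_avg <- beta1 * exp_avg + (1 - beta1) * update
    exp_avg.lerp_(update, 1 - beta1)
    # decoupled weight decay plus the descent step, applied to the "slow" parameter z
    z.mul_(1 - weight_decay * lr).add_(exp_avg, alpha=-lr)
    # the exposed parameter trails z by a fixed-radius SAM step along -exp_avg
    p.copy_(z - sam_step_size * exp_avg / exp_avg.norm().clamp(min=1e-8))

p = torch.randn(10)
z = p.clone()
exp_avg = torch.zeros(10)
grad = torch.randn(10)
msam_single_step(p, z, exp_avg, grad, lr=1e-3, beta1=0.9, weight_decay=1e-2, sam_step_size=1e-2)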
174
+
175
+
125
176
  def append_or_extend(base, new):
126
177
  if isinstance(new, list):
127
178
  base.extend(new)
@@ -161,7 +212,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
161
212
  new_shape = [grad.shape[0], *new_shape[::-1]]
162
213
  new_grad = grad.reshape(new_shape)
163
214
  if not split:
164
- return new_grad
215
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
165
216
 
166
217
  grads = [new_grad]
167
218
  for i, sh in reversed(list(enumerate(new_shape[:]))):
@@ -172,7 +223,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
172
223
  continue
173
224
  grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
174
225
  if len(grads) == 1:
175
- return new_grad
226
+ return new_grad.to(memory_format=torch.contiguous_format).contiguous()
176
227
  new_grads = []
177
228
  for g in grads:
178
229
  append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
@@ -279,16 +330,29 @@ def clean():
279
330
 
280
331
 
281
332
  def _ignore_warning(msg):
282
- warnings.filterwarnings("ignore", f".*{msg}.*")
333
+ warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
283
334
 
284
335
 
285
- def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
336
+ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
337
+ import opt_einsum as _opt_einsum
338
+
286
339
  cudnn.benchmark = True
287
340
  cudnn.deterministic = False
288
341
  cudnn.benchmark_limit = benchmark_limit
289
342
  torch.use_deterministic_algorithms(False)
290
343
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
291
- opt_einsum.set_flags(True, einsum_strategy)
344
+ opt_einsum.set_flags(True)
345
+ if einsum_strategy == "heavyball":
346
+ opt_einsum.strategy = "auto-hq"
347
+ choices = _opt_einsum.paths._AUTO_HQ_CHOICES
348
+ for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
349
+ if isinstance(fn, int):
350
+ fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
351
+ for i in range(max(choices.keys()), max_val):
352
+ if i not in choices:
353
+ choices[i] = fn
354
+ else:
355
+ opt_einsum.strategy = einsum_strategy
292
356
 
293
357
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
294
358
  _ignore_warning(
@@ -297,6 +361,9 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
297
361
  _ignore_warning(
298
362
  "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
363
  )
364
+ _ignore_warning(
365
+ "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
366
+ )
300
367
 
301
368
 
302
369
  @decorator
@@ -316,15 +383,6 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
316
383
  return X.to(G.dtype)
317
384
 
318
385
 
319
- def ortho(x):
320
- if zeroth_power_mode == "qr":
321
- return torch.linalg.qr(x).Q
322
- if zeroth_power_mode == "svd":
323
- u, _s, v = torch.linalg.svd(x)
324
- return u @ v.T
325
- raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
326
-
327
-
328
386
  @decorator_knowngood
329
387
  def _compilable_heavyball_momentum_(state, grad, beta):
330
388
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
@@ -377,7 +435,7 @@ def _compilable_grafting(magnitude, direction):
377
435
 
378
436
 
379
437
  @decorator_knowngood
380
- def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
438
+ def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
381
439
  if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
382
440
  y = zeropower_via_newtonschulz5(x, 5)
383
441
  elif mode == "qr":
@@ -395,9 +453,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
395
453
  y = _compilable_grafting(x, y)
396
454
  else:
397
455
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
456
+ if out is None:
457
+ return y
458
+
398
459
  set_(out, y)
399
460
 
400
461
 
462
+ def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
463
+ return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
464
+
465
+
401
466
  @decorator_knowngood
402
467
  def _compilable_scatter_set(target, source, index):
403
468
  target[:] = source.contiguous()[index].reshape_as(target)
@@ -413,6 +478,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
413
478
  :param Q: List of current eigenbases (updated in-place to Q_new).
414
479
  :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
415
480
  """
481
+ if exp_avg.dim() == 0: # preconditioning doesn't make sense here
482
+ Q.clear()
483
+ return
484
+
416
485
  if isinstance(Q, list) and not Q:
417
486
  return
418
487
 
@@ -430,10 +499,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
430
499
  q_old = promote(q.data)
431
500
 
432
501
  tmp = m @ q_old
433
- est_eig = torch.einsum("ij,ij->j", q_old, tmp)
502
+ est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
434
503
  sort_idx = torch.argsort(est_eig, descending=True)
435
504
 
436
- tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
505
+ tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
437
506
  new_qs.append(tmp)
438
507
 
439
508
  if exp_avg is None:
@@ -453,7 +522,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
453
522
  out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
454
523
 
455
524
  subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
456
- exp_avg_new = torch.einsum(
525
+ exp_avg_new = compiled_einsum(
457
526
  subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
527
  )
459
528
  copy_stochastic_(exp_avg, exp_avg_new)
@@ -568,6 +637,19 @@ def scalar_guard(*args):
568
637
  return out
569
638
 
570
639
 
640
+ def broadcastable_list_guard(*xs):
641
+ xs = list_guard(*xs)
642
+ for x in xs:
643
+ if isinstance(x[0], Tensor):
644
+ ref = x[0]
645
+ break
646
+ else:
647
+ raise ValueError("No tensor-valued input given")
648
+ xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
649
+ max_len = max(len(x) for x in xs)
650
+ return [x if len(x) > 1 else x * max_len for x in xs]
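A rough, self-contained illustration (hypothetical helper, independent of list_guard/scalar_guard) of the broadcasting rule applied here before the fused list ops below: plain scalars are wrapped into tensors that match a tensor reference, and length-1 lists are repeated to the longest length.

import torch

def broadcast_lists_sketch(*lists):
    # grab a tensor to borrow dtype/device from when wrapping plain scalars
    ref = next(x[0] for x in lists if isinstance(x[0], torch.Tensor))
    lists = [x if isinstance(x[0], torch.Tensor)
             else [torch.tensor(v, dtype=ref.dtype, device=ref.device) for v in x]
             for x in lists]
    longest = max(len(x) for x in lists)
    # repeat length-1 lists so that zip() pairs every entry of the longer list
    return [x if len(x) > 1 else x * longest for x in lists]

xs, alphas = broadcast_lists_sketch([torch.randn(3), torch.randn(3)], [2.0])
assert len(xs) == len(alphas) == 2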
651
+
652
+
571
653
  @decorator_knowngood
572
654
  def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
573
655
  for x_, y_ in zip(x, y):
@@ -577,7 +659,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
577
659
 
578
660
 
579
661
  def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
580
- x, y = list_guard(x, y)
662
+ x, y = broadcastable_list_guard(x, y)
581
663
  alpha = scalar_guard(alpha, x[0])
582
664
  _compilable_stochastic_add_(x, y, alpha)
583
665
 
@@ -591,7 +673,7 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
591
673
 
592
674
 
593
675
  def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
594
- x, y = list_guard(x, y)
676
+ x, y = broadcastable_list_guard(x, y)
595
677
  alpha, divisor = scalar_guard(alpha, divisor, x[0])
596
678
  _compilable_stochastic_add_divide_(x, y, alpha, divisor)
597
679
 
@@ -605,7 +687,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
605
687
 
606
688
 
607
689
  def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
608
- x, y = list_guard(x, y)
690
+ x, y = broadcastable_list_guard(x, y)
609
691
  _compilable_stochastic_multiply_(x, y)
610
692
 
611
693
 
@@ -624,7 +706,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
624
706
  b = einsum_base[idx]
625
707
  g0 = einsum_base[: grad.dim()]
626
708
  g1 = g0.replace(b, b.upper())
627
- outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
709
+ outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
628
710
  stochastic_lerp_(m, outer_product, 1 - beta)
629
711
 
630
712
 
@@ -706,7 +788,7 @@ def project(grad, Q, back: bool):
706
788
  preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
707
789
  if preconditioners:
708
790
  out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
709
- out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
791
+ out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
710
792
  grad = out.to(grad.dtype)
711
793
  return grad
712
794
 
@@ -714,24 +796,28 @@ def project(grad, Q, back: bool):
714
796
  @contextlib.contextmanager
715
797
  def patch_backward():
716
798
  @contextlib.contextmanager
717
- def _inner(module):
799
+ def patch_module(module):
718
800
  original = module.backward
719
-
720
- signature = inspect.signature(original)
721
-
722
- def patched_backward(*args, **kwargs):
723
- new_kwargs = signature.bind(*args)
724
- new_kwargs.apply_defaults()
725
- new_kwargs = new_kwargs.arguments
726
- new_kwargs.update(kwargs)
727
- new_kwargs["create_graph"] = True
728
- return original(**new_kwargs)
729
-
730
- module.backward = patched_backward
731
- yield
732
- module.backward = original
733
-
734
- with _inner(torch.Tensor), _inner(torch.autograd):
801
+ try:
802
+ signature = inspect.signature(original)
803
+
804
+ @functools.wraps(original)
805
+ def patched_backward(*args, **kwargs):
806
+ new_kwargs = signature.bind(*args)
807
+ new_kwargs.apply_defaults()
808
+ new_kwargs = new_kwargs.arguments
809
+ new_kwargs.update(kwargs)
810
+ new_kwargs["create_graph"] = True
811
+ return original(**new_kwargs)
812
+
813
+ module.backward = patched_backward
814
+ yield
815
+ finally:
816
+ module.backward = original
817
+
818
+ with contextlib.ExitStack() as stack:
819
+ stack.enter_context(patch_module(torch.Tensor))
820
+ stack.enter_context(patch_module(torch.autograd))
735
821
  yield
736
822
 
737
823
 
@@ -743,6 +829,9 @@ class ExactHVPFailed(ValueError):
743
829
  pass
744
830
 
745
831
 
832
+ use_default = object()
833
+
834
+
746
835
  class StatefulOptimizer(torch.optim.Optimizer):
747
836
  """
748
837
  finite_differences saves memory, but needs more compute. (Alternative is true HVP)
@@ -755,7 +844,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
755
844
  compile_step: bool = False
756
845
  hessian_approx: bool = False
757
846
  precond_schedule: Union[Callable, float, None] = None
758
- stochastic_schedule: bool = False
847
+ stochastic_schedule: bool | Literal[use_default] = use_default
759
848
  finite_differences: bool = False
760
849
  fallback_to_finite_differences: bool = True
761
850
  _fallback_enabled: bool = False
@@ -765,18 +854,61 @@ class StatefulOptimizer(torch.optim.Optimizer):
765
854
  super().__init__(params, {**defaults, "foreach": foreach})
766
855
  self.use_ema = use_ema
767
856
  self.mapping = {}
768
- self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
769
- self._precond_rng = random.Random(0x12312)
857
+ self.mapping_inverse = {}
858
+
859
+ if self.stochastic_schedule is use_default:
860
+ stochastic_schedule = None
861
+ for group in self.param_groups:
862
+ new = group.get("stochastic_schedule", stochastic_schedule)
863
+ if stochastic_schedule is not None and new != stochastic_schedule:
864
+ raise ValueError("All parameter groups must have the same stochastic_schedule.")
865
+ stochastic_schedule = new
866
+ self.stochastic_schedule = stochastic_schedule
867
+
868
+ self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
869
+ self.precond_rng = random.Random(0x12312)
770
870
  self._is_preconditioning = None
771
871
 
772
872
  if self.hessian_approx and self.compile_step:
773
873
  raise ValueError("Hessian approximation can't be used with compile_step.")
774
874
 
875
+ self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
876
+ self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
877
+ self._init_mapping()
878
+
879
+ def _store_stats(self, state_dict: dict[str, any]):
880
+ state_dict["heavyball"] = {
881
+ "inner_group": self.inner_group,
882
+ "precond_rng": pickle.dumps(self.precond_rng),
883
+ "use_ema": self.use_ema,
884
+ "ema_decay": self.ema_decay,
885
+ "compile_step": self.compile_step,
886
+ "hessian_approx": self.hessian_approx,
887
+ "precond_schedule": pickle.dumps(self.precond_schedule),
888
+ "stochastic_schedule": self.stochastic_schedule,
889
+ "fallback_to_finite_differences": self.fallback_to_finite_differences,
890
+ "_fallback_enabled": self._fallback_enabled,
891
+ "hvp_interval": self.hvp_interval,
892
+ }
893
+
894
+ def _load_stats(self, state_dict):
895
+ sd = state_dict.pop("heavyball", {})
896
+ for k, v in sd.items():
897
+ if k in ("precond_rng", "precond_schedule"):
898
+ v = pickle.loads(v)
899
+ setattr(self, k, v)
900
+
775
901
  def get_groups(self, group):
776
902
  return [group]
777
903
 
778
- def state_(self, arg: Tensor):
779
- return self.state[arg]
904
+ @functools.lru_cache(maxsize=None)
905
+ def state_(self, arg: Tensor, fail: bool = True):
906
+ if not fail and arg not in self.mapping:
907
+ return {}
908
+ state_param, index = self.mapping_inverse[arg]
909
+ if state_param not in self.state:
910
+ self.state[state_param] = collections.defaultdict(dict)
911
+ return self.state[state_param][index]
780
912
 
781
913
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
782
914
  for p, g in zip(p_list, g_list):
@@ -786,6 +918,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
786
918
  old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
787
919
  mars_correction(g_list, old_gs, mars_gamma, beta)
788
920
 
921
+ def _init_mapping(self, group: dict | None = None):
922
+ if group is None:
923
+ for group in self.param_groups:
924
+ self._init_mapping(group)
925
+ return
926
+
927
+ for p in group["params"]:
928
+ if p not in self.mapping:
929
+ self.mapping[p] = p_views = merge_group(group, p)
930
+ for i, pv in enumerate(p_views):
931
+ self.mapping_inverse[pv] = (p, i)
932
+
789
933
  def split_p_and_g_in_group(
790
934
  self,
791
935
  group: dict,
@@ -809,6 +953,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
809
953
  p_views = self.mapping[p]
810
954
  else:
811
955
  self.mapping[p] = p_views = merge_group(group, p)
956
+ for i, pv in enumerate(p_views):
957
+ self.mapping_inverse[pv] = (p, i)
812
958
 
813
959
  vector = getattr(p, "vector", None)
814
960
  hessian_vector = getattr(p, "hessian_vector", None)
@@ -957,8 +1103,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
957
1103
  raise ValueError("Hessian approximation requires a closure.")
958
1104
  return None
959
1105
 
960
- step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
- if not hessian_approx or step % self.hvp_interval == 0:
1106
+ step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
1107
+ if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
962
1108
  with torch.enable_grad():
963
1109
  loss = closure()
964
1110
  return loss
@@ -997,12 +1143,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
997
1143
  if self.precond_schedule is None:
998
1144
  self._is_preconditioning = False
999
1145
  else:
1000
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
1146
+ self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
1001
1147
  loss = self._handle_closure(closure)
1002
1148
 
1003
1149
  # we assume that parameters are constant and that there are no excessive recompiles
1004
1150
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
1005
1151
  for group in self.param_groups:
1152
+ if "param_count" not in group:
1153
+ group["param_count"] = sum(p.numel() for p in group["params"])
1006
1154
  group["is_preconditioning"] = self._is_preconditioning
1007
1155
  self._step(group)
1008
1156
  if self.use_ema:
@@ -1306,83 +1454,115 @@ def stable_exp(x: Tensor):
1306
1454
  return torch.where(x > 0, 1 / (-x).exp(), x.exp())
1307
1455
 
1308
1456
 
1457
+ def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
1458
+ # ln(mean(x ** pow) ** (1 / pow / 2))
1459
+ normalization = math.log(x.numel())
1460
+ x = x.double()
1461
+ x = x.abs()
1462
+ x = x.clamp(min=eps)
1463
+ x = x.log()
1464
+ x = x * pow
1465
+ x = x.flatten()
1466
+ x = x.logsumexp(dim=0) # log(sum(exp(log(x) * pow))) - more stable than sum(x ** pow)
1467
+ x = x - normalization # sum -> mean (divide by x.numel() in log space)
1468
+ return x / pow / 2
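A quick numerical check (illustrative only) that the log-sum-exp form above matches the direct formula ln(mean(|x| ** pow) ** (1 / (2 * pow))) when clamping never triggers.

import math
import torch

def lse_mean_sketch(x, p, eps=1e-12):
    # log-space version of ln(mean(|x| ** p) ** (1 / (2 * p)))
    x = x.double().abs().clamp(min=eps).log().flatten()
    return ((x * p).logsumexp(dim=0) - math.log(x.numel())) / p / 2

x = torch.rand(1000) + 0.1
direct = (x.double().pow(4).mean() ** (1 / 8)).log()
assert torch.allclose(lse_mean_sketch(x, 4), direct, atol=1e-6)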
1469
+
1470
+
1309
1471
  @decorator_knowngood
1310
1472
  def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
1311
1473
  # 1 / (mean(x ** pow) ** (1 / pow / 2))
1312
- log_x = x.double().abs().clamp(min=eps).log()
1313
- log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
1314
- return stable_exp(-log_mean_x_pow / pow / 2)
1474
+ return stable_exp(-_lse_mean(x, pow, eps))
1315
1475
 
1316
1476
 
1317
1477
  @decorator_knowngood
1318
1478
  def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
1319
1479
  # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
1320
- log_x = x.double().abs().clamp(min=eps).log()
1321
- log_y = y.double().abs().clamp(min=eps).log()
1322
-
1323
- x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
1324
- x_normed = x_normed / pow0 / 2
1480
+ return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
1325
1481
 
1326
- y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
1327
- y_normed = y_normed / pow1 / 2
1328
1482
 
1329
- return stable_exp(x_normed - y_normed)
1483
+ class PrecondInitError(ValueError):
1484
+ pass
1330
1485
 
1331
1486
 
1332
- def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
1487
+ def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
1333
1488
  automatic_scale = True
1334
1489
  manual_hint = " Set it manually using `precond_init_scale=0.1`"
1490
+ scale_scale = 1 if scale_scale is None else scale_scale
1335
1491
 
1336
1492
  if scale is not None:
1337
1493
  automatic_scale = False
1338
1494
  warn_once(
1339
1495
  "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
1340
1496
  )
1341
- if scale_scale is not None and scale_scale != 1:
1497
+ if scale_scale != 1:
1342
1498
  warn_once(
1343
- "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
1499
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
1500
+ )
1501
+ if scale_power is not None:
1502
+ warn_once(
1503
+ "precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
1344
1504
  )
1345
1505
  elif hessian_vector is None:
1346
1506
  scale = mean_root(grad, 4) * scale_scale
1347
1507
  else:
1348
1508
  scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
1349
1509
 
1510
+ if automatic_scale:
1511
+ scale_power = 0.5 if scale_power is None else scale_power
1512
+ scale = scale**scale_power
1513
+
1350
1514
  if isinstance(scale, torch.Tensor):
1351
1515
  scale = scale.item() # slow, but necessary
1352
1516
 
1353
1517
  if np.isfinite(scale):
1354
- if scale > scale_max or scale < 1 / scale_max: # fallthrough to later checks
1518
+ if scale > scale_max: # fallthrough to later checks
1355
1519
  warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
1356
1520
  else:
1357
1521
  return scale
1358
1522
 
1359
1523
  if not automatic_scale:
1360
- raise ValueError("The manually set precond_init_scale is not finite")
1524
+ raise PrecondInitError("The manually set precond_init_scale is not finite")
1361
1525
 
1362
1526
  for x in (grad, hessian_vector, vector):
1363
1527
  if x is None:
1364
1528
  continue
1365
- if torch.allclose(x, torch.zeros_like(x)).item():
1366
- raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
1529
+ if torch.allclose(x, torch.zeros_like(x)):
1530
+ raise PrecondInitError(
1531
+ f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
1532
+ )
1367
1533
  if not torch.isfinite(x).all().item():
1368
- raise ValueError("Grad or HVP is not finite")
1534
+ raise PrecondInitError("Grad or HVP is not finite")
1369
1535
 
1370
1536
  if np.isfinite(scale):
1371
1537
  return scale
1372
1538
 
1373
- raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1539
+ raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
1374
1540
 
1375
1541
 
1376
- def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
1377
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1378
- U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1379
- V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1542
+ def init_lra(
1543
+ grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
1544
+ ):
1545
+ # "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
1546
+ # https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
1547
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1548
+ uv_scale = (param_count * (rank + eps)) ** -0.5
1549
+ U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1550
+ V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
1380
1551
  d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
1381
1552
  return U, V, d
1382
1553
 
1383
1554
 
1384
1555
  def init_Q_exprs(
1385
- grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1556
+ grad,
1557
+ scale,
1558
+ scale_scale,
1559
+ scale_power,
1560
+ max_size,
1561
+ min_ndim_triangular,
1562
+ memory_save_mode,
1563
+ hessian_vector,
1564
+ vector,
1565
+ dtype=None,
1386
1566
  ):
1387
1567
  """
1388
1568
  For a scalar or tensor `grad`, we initialize its preconditioner Q and
@@ -1391,21 +1571,13 @@ def init_Q_exprs(
1391
1571
  precond init scale computation from
1392
1572
  https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
1393
1573
  """
1394
- scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1395
- letters = string.ascii_lowercase + string.ascii_uppercase
1574
+ scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
1396
1575
  dtype = dtype if dtype is not None else grad.dtype
1397
1576
  shape = grad.shape
1398
1577
 
1399
1578
  if len(shape) == 0: # scalar
1400
1579
  Q = [scale * torch.ones_like(grad, dtype=dtype)]
1401
- exprA = ",->"
1402
- exprGs = [",->"]
1403
- exprP = ",,->"
1404
- return [Q, (exprA, tuple(exprGs), exprP)]
1405
-
1406
- # Tensor
1407
- if len(shape) > 13:
1408
- raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
1580
+ return Q
1409
1581
 
1410
1582
  scale = scale ** (1 / len(shape))
1411
1583
 
@@ -1418,6 +1590,9 @@ def init_Q_exprs(
1418
1590
  sorted_shape = sorted(shape)
1419
1591
  if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
1420
1592
  dim_diag[_max_idx(shape)] = True
1593
+ elif memory_save_mode == "one_triu":
1594
+ shape_ranks = np.argsort(np.argsort(shape)) # ranks
1595
+ dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
1421
1596
  elif memory_save_mode == "all_diag":
1422
1597
  dim_diag = [True for _ in shape]
1423
1598
  else:
@@ -1427,66 +1602,90 @@ def init_Q_exprs(
1427
1602
  )
1428
1603
 
1429
1604
  Q = []
1430
- piece1A, piece2A, piece3A = ([], "", "")
1431
- exprGs = []
1432
- piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
1433
1605
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
1434
1606
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
1435
1607
  # use diagonal matrix as preconditioner for this dim
1436
1608
  Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
1437
-
1438
- piece1A.append(letters[i])
1439
- piece2A = piece2A + letters[i]
1440
- piece3A = piece3A + letters[i]
1441
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1442
- subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1443
- exprGs.append(subscripts)
1444
- piece1P.append(letters[i + 13])
1445
- piece2P.append(letters[i + 13])
1446
- piece3P = piece3P + letters[i + 13]
1447
- piece4P = piece4P + letters[i + 13]
1448
1609
  else:
1449
1610
  # use triangular matrix as preconditioner for this dim
1450
1611
  Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
1451
- piece1A.append(letters[i] + letters[i + 13])
1452
- piece2A = piece2A + letters[i + 13]
1453
- piece3A = piece3A + letters[i]
1454
- piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1455
- piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1456
- subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
1457
- exprGs.append(subscripts)
1458
- a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1459
- piece1P.append(a + b)
1460
- piece2P.append(a + c)
1461
- piece3P = piece3P + c
1462
- piece4P = piece4P + b
1463
-
1464
- exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
1465
- exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
1466
- return [Q, (exprA, tuple(exprGs), exprP)]
1612
+ return Q
1467
1613
 
1468
1614
 
1469
- @decorator
1470
- def psgd_balance_Q(Q_in):
1471
- norms = torch.stack([q.norm(float("inf")) for q in Q_in])
1472
- geometric_mean = norms.log().mean().exp()
1473
- norms = geometric_mean / norms
1474
- torch._foreach_mul_(Q_in, list(norms))
1615
+ @decorator_knowngood
1616
+ def psgd_balance_Q(Q):
1617
+ norms = [promote(q.norm(float("inf"))).log() for q in Q]
1618
+ geometric_mean = sum([n for n in norms]) / len(Q)
1619
+ for q, n in zip(Q, norms):
1620
+ q *= (geometric_mean - n).exp()
1475
1621
 
1476
1622
 
1477
- @decorator
1478
- def psgd_balance_lra(U: Tensor, V: Tensor):
1479
- u_norm = promote(torch.linalg.vector_norm(U))
1480
- v_norm = promote(torch.linalg.vector_norm(V))
1481
- scale = (u_norm / v_norm) ** 0.5
1482
- U.div_(scale)
1483
- V.mul_(scale)
1623
+ @decorator_knowngood
1624
+ def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
1625
+ u_norm = sum(u.square().sum().double() for u in U)
1626
+ v_norm = sum(v.square().sum().double() for v in V)
1627
+ scale = (u_norm / v_norm) ** 0.25 # u_norm/v_norm are squared L2 norms; the extra sqrt splits the rebalancing across the two factors
1628
+ scale = torch.where(torch.logical_and(torch.isfinite(scale), scale > 1e-6), scale, 1)
1629
+ stochastic_multiply_(U, [1 / scale] * len(U))
1630
+ stochastic_multiply_(V, [scale] * len(V))
1631
+ return multi_flatten((U, 1), (V, 1), (d, 0))
1484
1632
 
1485
1633
 
1486
1634
  @decorator
1487
1635
  def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1488
1636
  dtype = min_dtype([U, V, x])
1489
- return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1637
+ return x + compiled_einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1638
+
1639
+
1640
+ @decorator_knowngood
1641
+ def _compilable_d_step(
1642
+ d: Tensor,
1643
+ d_orig: List[Tensor],
1644
+ invQtv: Tensor,
1645
+ vector: Tensor,
1646
+ inverse_precond_vector: Tensor,
1647
+ hessian_vector: Tensor,
1648
+ precond_hessian_vector: Tensor,
1649
+ eps: Tensor,
1650
+ step: Tensor,
1651
+ delayed: bool,
1652
+ ):
1653
+ precond_hessian_vector = promote(precond_hessian_vector)
1654
+ hessian_vector = promote(hessian_vector)
1655
+ vector = promote(vector)
1656
+ inverse_precond_vector = promote(inverse_precond_vector)
1657
+ invQtv = promote(invQtv)
1658
+ inverse_precond_vector = invQtv - inverse_precond_vector
1659
+
1660
+ nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector
1661
+
1662
+ """
1663
+ 1) Sketching
1664
+ 1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
1665
+ 1.2) reduced-precision selection of largest element (halves memory traffic)
1666
+ 2) Computation
1667
+ 2.1) select relevant indices
1668
+ 2.2) redo 1.1 in double precision for scalar values
1669
+ 2.3) return high-precision normalized step-size
1670
+ overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
1671
+ improving precision
1672
+ """
1673
+ a0 = promote(d) * precond_hessian_vector
1674
+ a1 = vector
1675
+ b0 = inverse_precond_vector / promote(d)
1676
+ b1 = hessian_vector
1677
+
1678
+ divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
1679
+ idx = divisor.bfloat16().flatten().argmax()
1680
+ a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
1681
+ b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
1682
+ divisor = (a * b).sqrt().clamp(min=eps)
1683
+ step = -step / divisor
1684
+
1685
+ # fused update(s)
1686
+ apply_flat_add(d_orig, nablaD, step)
1687
+ if not delayed:
1688
+ copy_stochastic_(d, promote(d) - nablaD * step)
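The numbered comment above describes a sketch-then-refine pattern; a minimal stand-alone version of it (illustrative values only) picks the largest divisor candidate with a cheap reduced-precision argmax and then recomputes just that entry in float64.

import torch

a0, a1 = torch.randn(10_000), torch.randn(10_000)
b0, b1 = torch.randn(10_000), torch.randn(10_000)

# 1) sketch: locate the largest divisor candidate with a cheap reduced-precision argmax
divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
idx = divisor.bfloat16().argmax()

# 2) refine: recompute only the selected entry in float64 before taking the reciprocal
a = a0[idx].double().square() + a1[idx].double().square()
b = b0[idx].double().square() + b1[idx].double().square()
step_size = 1.0 / (a * b).sqrt().clamp(min=1e-12)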
1490
1689
 
1491
1690
 
1492
1691
  def update_lra_precond_(
@@ -1498,13 +1697,14 @@ def update_lra_precond_(
1498
1697
  eps: float,
1499
1698
  step: float,
1500
1699
  delayed: bool,
1700
+ precond_u: bool,
1501
1701
  ):
1502
1702
  """
1503
1703
  Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1504
1704
  """
1505
1705
  U_orig, V_orig, d_orig = U, V, d
1506
1706
 
1507
- U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
1707
+ U, V, d = _lra_flatten_and_balance(U, V, d)
1508
1708
 
1509
1709
  dtype = min_dtype([U, V, vector, hessian_vector])
1510
1710
  U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
@@ -1512,10 +1712,10 @@ def update_lra_precond_(
1512
1712
  eps = scalar_guard(eps, vector)
1513
1713
 
1514
1714
  Qh = low_rank_mm(U, V, d * hessian_vector)
1515
- Ph = d * low_rank_mm(V, U, Qh)
1715
+ Ph = low_rank_mm(V, U, Qh)
1516
1716
  rank = U.size(1)
1517
1717
 
1518
- VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
1718
+ VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
1519
1719
  I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
1520
1720
  IpVtU = I + VtU
1521
1721
  invQtv = vector / d
@@ -1533,47 +1733,39 @@ def update_lra_precond_(
1533
1733
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1534
1734
 
1535
1735
  invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
1536
- invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1537
- invPv = invPv / d
1736
+ invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1538
1737
 
1539
- nablaD = Ph * hessian_vector - vector * invPv
1540
- divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
1541
- divisor = divisor.add(eps).sqrt().max()
1542
- d_step = step / divisor
1543
-
1544
- apply_flat_add(d_orig, d * nablaD, -d_step)
1738
+ eps, step = scalar_guard(eps, step, vector)
1739
+ _compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
1545
1740
 
1546
1741
  a, b = Qh, invQtv
1547
1742
 
1548
- precond_u = random.random() < 0.5 # update either U or V, not both at the same time
1549
1743
  precond = V if precond_u else U
1550
- atV = torch.einsum("b,br->r", a, precond) # o == one
1551
- btV = torch.einsum("b,br->r", b, precond)
1552
- atVVt = torch.einsum("r,br->b", atV, precond)
1553
- btVVt = torch.einsum("r,br->b", btV, precond)
1554
- precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
1744
+ atV = compiled_einsum("b,br->r", a, precond) # o == one
1745
+ btV = compiled_einsum("b,br->r", b, precond)
1746
+ atVVt = compiled_einsum("r,br->b", atV, precond)
1747
+ btVVt = compiled_einsum("r,br->b", btV, precond)
1748
+ precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
1555
1749
  if precond_u:
1556
- a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
1557
- b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
1750
+ a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
1751
+ b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
1558
1752
  else:
1559
- a = a + torch.einsum("br,r->b", V, atV)
1560
- b = b + torch.einsum("br,r->b", V, btV)
1561
- a = torch.einsum("b,r->br", a, atV)
1562
- b = torch.einsum("b,r->br", b, btV)
1753
+ a = a + compiled_einsum("br,r->b", V, atV)
1754
+ b = b + compiled_einsum("br,r->b", V, btV)
1755
+ a = compiled_einsum("b,r->br", a, atV)
1756
+ b = compiled_einsum("b,r->br", b, btV)
1563
1757
  apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
1564
-
1565
1758
  if not delayed:
1566
- stochastic_add_([d], [d * nablaD], -d_step)
1567
1759
  stochastic_add_([U if precond_u else V], [b - a], precond_step)
1568
1760
  return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1569
1761
 
1570
1762
 
1571
- def lra_precond(U, V, d, g):
1763
+ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
1572
1764
  """
1573
1765
  As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
1574
1766
  """
1575
- g = low_rank_mm(U, V, d * g)
1576
- return d * low_rank_mm(V, U, g)
1767
+ new_g = low_rank_mm(U, V, d * g)
1768
+ return d * low_rank_mm(V, U, new_g)
1577
1769
 
1578
1770
 
1579
1771
  @decorator_knowngood
@@ -1584,16 +1776,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
1584
1776
 
1585
1777
 
1586
1778
  @decorator_knowngood
1587
- def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
1588
- update = lra_precond(U, V, d, update)
1779
+ def _compilable_lra_update_(
1780
+ params: List[Tensor],
1781
+ update: List[Tensor],
1782
+ U: Tensor,
1783
+ V: Tensor,
1784
+ d: Tensor,
1785
+ lr: Tensor,
1786
+ decay: Tensor,
1787
+ caution: bool,
1788
+ grads: List[Tensor],
1789
+ ):
1790
+ update = lra_precond(U, V, d, flatten(update))
1589
1791
  start = 0
1590
1792
  update = update.flatten()
1591
- for p in params:
1793
+ for p, g in zip(params, grads):
1592
1794
  size = p.numel()
1593
- copy_stochastic_(p, update[start : start + size].view_as(p))
1795
+ update_param_(p, update[start : start + size].view_as(p), lr, decay, caution, g)
1594
1796
  start += size
1595
1797
 
1596
1798
 
1799
+ def apply_lra_update(
1800
+ params: List[Tensor],
1801
+ update: Tensor,
1802
+ U: Tensor,
1803
+ V: Tensor,
1804
+ d: Tensor,
1805
+ lr: float,
1806
+ decay: float,
1807
+ caution: bool,
1808
+ grads: List[Tensor],
1809
+ ):
1810
+ params, grads = list_guard(params, grads)
1811
+ lr, decay = scalar_guard(lr, decay, params[0])
1812
+ _compilable_lra_update_(params, update, U, V, d, lr, decay, caution, grads)
1813
+
1814
+
1597
1815
  @decorator_knowngood
1598
1816
  def apply_flat_update(params: List[Tensor], update: Tensor):
1599
1817
  start = 0
@@ -1604,6 +1822,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
1604
1822
  start += size
1605
1823
 
1606
1824
 
1825
+ @decorator_knowngood
1826
+ def zero_(x: List[Tensor]):
1827
+ for i in x:
1828
+ i.zero_()
1829
+
1830
+
1607
1831
  @decorator_knowngood
1608
1832
  def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1609
1833
  start = 0
@@ -1629,7 +1853,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1629
1853
  @decorator_knowngood
1630
1854
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1631
1855
  last_dim = x[0].shape[-remaining:] if remaining else []
1632
- return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
1856
+ return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
1857
+
1858
+
1859
+ @decorator_knowngood
1860
+ def multi_flatten(*xs: Tuple[List[Tensor], int]):
1861
+ return [flatten(x, i) for x, i in xs]
1633
1862
 
1634
1863
 
1635
1864
  @decorator_knowngood
@@ -1645,110 +1874,275 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1645
1874
 
1646
1875
  def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1647
1876
  md = min_dtype(args)
1648
- return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1877
+ return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1649
1878
 
1650
1879
 
1651
1880
  @decorator_knowngood
1652
1881
  def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
1653
1882
  triangular_qs = []
1883
+ conjB = promote(conjB)
1654
1884
  for i, q in enumerate(Qs):
1655
1885
  q = promote(q)
1656
1886
  if q.dim() <= 1:
1657
- shape = [1] * conjB.ndim
1658
- shape[i] = -1
1659
- conjB /= q.view(shape)
1887
+ if conjB.ndim == 0:
1888
+ conjB = conjB / q
1889
+ else:
1890
+ shape = [1] * conjB.ndim
1891
+ shape[i] = -1
1892
+ conjB = conjB / q.view(shape)
1660
1893
  else:
1661
1894
  triangular_qs.append((i, q))
1662
- return triangular_qs
1895
+ return triangular_qs, conjB
1663
1896
 
1664
1897
 
1665
1898
  @decorator_knowngood
1666
- def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
1899
+ def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
1900
+ solved = solved.reshape(transposed_shape)
1901
+ solved = solved.transpose(-1, last_dim)
1667
1902
  solved = solved.reshape(original_shape)
1668
- solved.transpose(last_dim, -1)
1669
- return solved.reshape(new_shape).contiguous()
1903
+ solved = solved.transpose(-1, new_dim)
1904
+ return solved.contiguous(), solved.shape
1905
+
1906
+
1907
+ def ndim_tuple(Q: list[Tensor]) -> tuple:
1908
+ return tuple(q.ndim for q in Q)
1670
1909
 
1671
1910
 
1672
- def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1673
- order = G.dim()
1674
- if order > 1:
1675
- conjB = conjB.view_as(G).permute(*range(1, order), 0)
1676
- conjB = conjB.to(promote(G.dtype))
1911
+ def psgd_calc_A_and_conjB(G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1912
+ exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
1677
1913
  A = casted_einsum(exprA, *Q, G)
1678
1914
  solve = torch.compiler.disable(torch.linalg.solve_triangular)
1679
- original_shape = conjB.shape
1915
+ transposed_shape = original_shape = conjB.shape
1680
1916
  prev_i = -1
1681
- for i, tri_q in _psgd_calc_scalars_(Q, conjB):
1682
- conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
1917
+ qs, conjB = _psgd_calc_scalars_(Q, conjB)
1918
+ for i, tri_q in qs:
1919
+ conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
1683
1920
  prev_i = i
1684
1921
  conjB = solve(tri_q, conjB, upper=True, left=False)
1685
- conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
1922
+ conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
1686
1923
  return A, conjB
1687
1924
 
1688
1925
 
1689
1926
  @decorator_knowngood
1690
- def _max_select(to_index: Tensor, to_argmax: Tensor):
1691
- idx = to_argmax.argmax()
1692
- return to_index.index_select(1, idx).flatten().contiguous()
1927
+ def _random_projection(x: Tensor, scale: Optional[Tensor]):
1928
+ if scale is None:
1929
+ scale = x.norm(float("inf")).clamp(min=1e-8)
1930
+ k = 2 ** math.ceil(math.log2(math.log2(min(x.shape)))) # next-largest-power-of-2 of log2-of-size
1931
+ norm = x.square().sum(0)
1932
+ indices = torch.topk(norm, k, largest=True).indices
1933
+ return x.index_select(1, indices).contiguous() / scale, scale
1693
1934
 
1694
1935
 
1695
- def psgd_lb(A: Tensor, max_abs: Tensor):
1696
- A /= max_abs
1697
- x = _max_select(A, torch.einsum("ij,ij->j", A, A))
1698
- x = torch.einsum("i,ij->j", x, A)
1699
- x /= x.norm()
1700
- x = torch.einsum("j,kj->k", x, A)
1701
- x = x.norm()
1702
- x *= max_abs
1703
- return x
1936
+ def max_singular_value_exact(A, use_lobpcg: bool = False):
1937
+ try:
1938
+ if use_lobpcg:
1939
+ A = A @ A.T
1940
+ eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
1941
+ return eigval[0].sqrt()
1942
+ else:
1943
+ return torch.linalg.svd(A, driver="gesvdj")[1].max() # == linalg.matrix_norm(A, ord=2)
1944
+ except torch.linalg.LinAlgError:
1945
+ return torch.zeros((), device=A.device, dtype=A.dtype)
1704
1946
 
1705
1947
 
1706
1948
  @decorator_knowngood
1707
- def _subtract_from_line_(state: Tensor, term: Tensor):
1708
- stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
1949
+ def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
1950
+ """
1951
+ Rayleigh quotient of row with the largest norm + optional power iterations
1952
+ """
1953
+ x_norm, max_idx = A.norm(dim=1).max(dim=0)
1954
+ x = A.index_select(0, max_idx).flatten().contiguous()
1955
+ A = A / x_norm
1956
+ x = x / x_norm
1957
+ for _ in range(iterations):
1958
+ x = A.T.mv(A.mv(x)) # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
1959
+ x = x / x.norm()
1960
+ return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm
1709
1961
 
1710
1962
 
1711
1963
  @decorator_knowngood
1712
- def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
1713
- out = term1.float().triu() * fac
1714
- out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
1715
- copy_stochastic_(term1, out)
1964
+ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
1965
+ """
1966
+ Adapted from @evanatyourservice
1967
+ """
1968
+ Y, max_abs = _random_projection(A, max_abs)
1969
+ Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
1970
+ Q = Q / max_abs
1971
+ Z = A.T @ Q
1972
+ W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
1973
+ sketch_norm = max_singular_value_exact(Z.T @ W)
1974
+ return sketch_norm * max_abs
1716
1975
 
1717
1976
 
1718
1977
  @decorator_knowngood
1719
- def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
1720
- copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
1978
+ def max_singular_value(
1979
+ A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
1980
+ ) -> Tensor:
1981
+ if min(A.shape) <= max_svd:
1982
+ return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
1983
+ if use_cholesky or power_iter < 0:
1984
+ return max_singular_value_cholesky(A, max_abs)
1985
+ return max_singular_value_power_iter(A, None, iterations=power_iter)
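A small self-contained comparison (illustrative; the error behavior is not a guarantee from the package) of the warm-started power-iteration estimate against the exact spectral norm.

import torch

def sigma_max_power_iter_sketch(A, iterations=5):
    # warm start from the row with the largest norm, then power-iterate on A^T A
    x_norm, idx = A.norm(dim=1).max(dim=0)
    x = A[idx] / x_norm
    A = A / x_norm
    for _ in range(iterations):
        x = A.T.mv(A.mv(x))
        x = x / x.norm()
    return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm

A = torch.randn(256, 128)
approx = sigma_max_power_iter_sketch(A)
exact = torch.linalg.matrix_norm(A, ord=2)
print(((approx - exact).abs() / exact).item())  # small relative error, shrinking with more iterations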
1986
+
1987
+
1988
+ @decorator_knowngood
1989
+ def _psgd_default_preconditioner_grad(
1990
+ terms: List[Tuple[Tensor, Tensor]],
1991
+ Q: List[Tensor],
1992
+ ) -> List[Tensor]:
1993
+ out = []
1994
+ for q, (x, y) in zip(Q, terms):
1995
+ x = promote(x)
1996
+ y = promote(y)
1997
+ update = x - y
1998
+ if q.ndim < 2:
1999
+ update = q * update
2000
+ else:
2001
+ update = (q @ update).triu()
2002
+ out.append(update)
2003
+ return out
1721
2004
 
1722
2005
 
1723
2006
  @decorator_knowngood
1724
- def _compilable_add_sub_(x: Tensor, y: Tensor):
1725
- x = promote(x)
1726
- y = promote(y)
1727
- return x - y, x + y
2007
+ def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
2008
+ if isinstance(Q[0], tuple):
2009
+ psgd_balance_Q([o[1] for o in Q])
2010
+ return line_to_triu(Q, symmetric_output)
2011
+ psgd_balance_Q(Q)
2012
+ return Q
2013
+
2014
+
2015
+ @functools.lru_cache(maxsize=None)
2016
+ def calcG_expr(q_dim, g_dim):
2017
+ exprs = []
2018
+ base = einsum_base[:g_dim]
2019
+ for i, q in enumerate(q_dim):
2020
+ new = list(base)
2021
+ if q == 2:
2022
+ new[i] = "Z"
2023
+ out = f"{base[i]}Z"
2024
+ else:
2025
+ out = base[i]
2026
+ exprs.append(f"{base},{''.join(new)}->{out}")
2027
+ return exprs
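For intuition, the strings this builder produces for a 2-D gradient with a diagonal (1-D) first factor and a triangular (2-D) second factor; the snippet re-implements the builder inline so it runs stand-alone.

import string

einsum_base = string.ascii_lowercase

def calc_g_expr_sketch(q_ndims, g_ndim):
    exprs = []
    base = einsum_base[:g_ndim]
    for i, q in enumerate(q_ndims):
        new = list(base)
        if q == 2:
            new[i] = "Z"
            out = f"{base[i]}Z"
        else:
            out = base[i]
        exprs.append(f"{base},{''.join(new)}->{out}")
    return exprs

# diagonal preconditioner on dim 0, triangular preconditioner on dim 1
print(calc_g_expr_sketch((1, 2), 2))  # ['ab,ab->a', 'ab,aZ->bZ']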
1728
2028
 
1729
2029
 
1730
2030
  @decorator
1731
- def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
2031
+ def psgd_update_precond(
2032
+ G: Tensor,
2033
+ precond_lr: float,
2034
+ oq: "TriuOrLine",
2035
+ store_triu_as_line: bool,
2036
+ velocity: Optional[List[Tensor]],
2037
+ beta2: float,
2038
+ ortho_method: Optional[str],
2039
+ V: Tensor,
2040
+ running_lower_bound: List[Tensor],
2041
+ lower_bount_beta: float,
2042
+ power_iter: int,
2043
+ ) -> None:
1732
2044
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1733
- exprA, exprGs, _ = exprs
1734
- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1735
- precond_lr = scalar_guard(precond_lr, G)
1736
-
1737
- for q, exprG, o in zip(Q, exprGs, oq):
1738
- term1 = torch.einsum(exprG, A, A)
1739
- term2 = torch.einsum(exprG, conjB, conjB)
1740
- term1, term2 = _compilable_add_sub_(term1, term2)
1741
- norm = term2.norm(float("inf"))
1742
- if q.dim() < 2:
1743
- _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
2045
+ Q = _balance_to_triu(oq)
2046
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2047
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2048
+
2049
+ A, conjB = psgd_calc_A_and_conjB(G, Q, V)
2050
+ terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
2051
+ del A, conjB, V
2052
+ updates = _psgd_default_preconditioner_grad(terms, Q)
2053
+ _psgd_precond_update_(
2054
+ updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2055
+ )
2056
+ return None
2057
+
2058
+
2059
+ @decorator_knowngood
2060
+ def _psgd_precond_update_(
2061
+ matmuled: List[Optional[Tensor]],
2062
+ Q: "TriuOrLine",
2063
+ running_lower_bound: List[Tensor],
2064
+ lower_bount_beta: Tensor,
2065
+ precond_lr: Tensor,
2066
+ store_triu_as_line: bool,
2067
+ power_iter: int,
2068
+ ):
2069
+ for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
2070
+ if isinstance(oq, tuple):
2071
+ oq = oq[1]
2072
+
2073
+ q = promote(oq)
2074
+ if update.ndim < 2:
2075
+ lb = update.norm(float("inf"))
1744
2076
  else:
1745
- lower_bound = psgd_lb(term2, norm)
1746
- _prescale_term_(term1, precond_lr, lower_bound, norm)
1747
- torch.mm(term1, q.to(term1.dtype), out=term1)
1748
- if store_triu_as_line:
1749
- _subtract_from_line_(q, term1)
2077
+ lb = max_singular_value(update, None, power_iter=power_iter)
2078
+ update = promote(update)
2079
+ if store_triu_as_line:
2080
+ update = triu_to_line([update])[0][1]
2081
+
2082
+ lb = promote(lb)
2083
+ lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
2084
+ copy_stochastic_(lb_state, lb)
2085
+ copy_stochastic_(oq, q - update / lb * precond_lr)
2086
+
2087
+
2088
+ @decorator_knowngood
2089
+ def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
2090
+ """
2091
+ I: Identity
2092
+ T: update / gg / target
2093
+ Q: q, preconditioner
2094
+ scale: scalar scale
2095
+ ---
2096
+ U = T * scale - I
2097
+ F = I - U # = 2I - T * scale
2098
+ O = Q - F @ Q @ F
2099
+ """
2100
+ out = []
2101
+ for gg, q in zip(GG, Q):
2102
+ if gg.ndim < 2:
2103
+ scale = max(1, gg.numel()) / numel
2104
+ target = promote(gg)
2105
+ update = target * scale - 1
2106
+ out.append(q - (1 - update) * q * (1 - update))
1750
2107
  else:
1751
- stochastic_add_(o, term1, -1)
2108
+ scale = gg.size(0) / numel
2109
+ gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
2110
+ update = q - gg @ q @ gg
2111
+ out.append(update + update.T) # make matrix symmetric
2112
+ return out
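A small stand-alone check (illustrative) of the 2-D branch above: when the incoming statistics already equal numel / size times the identity, F reduces to the identity, the computed update is zero, and Q is a fixed point of this rule.

import torch

n, numel = 4, 16                        # e.g. one factor of a 4x4 parameter
q = torch.randn(n, n)
gg = torch.eye(n) * numel / n           # "whitened" statistics: gg * (n / numel) == I
scale = n / numel
f = 2 * torch.eye(n) - gg * scale       # == I
update = q - f @ q @ f                  # == 0
assert torch.allclose(update + update.T, torch.zeros(n, n), atol=1e-6)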
2113
+
2114
+
2115
+ @decorator
2116
+ def inverse_free_psgd_update_precond(
2117
+ G: Tensor,
2118
+ precond_lr: float,
2119
+ oq: List[Tensor],
2120
+ store_triu_as_line: bool,
2121
+ velocity: Optional[List[Tensor]],
2122
+ beta2: float,
2123
+ ortho_method: Optional[str],
2124
+ V: None,
2125
+ running_lower_bound: List[Tensor],
2126
+ lower_bount_beta: float,
2127
+ power_iter: int,
2128
+ ) -> Tensor:
2129
+ """Update Kronecker product preconditioner Q with pair (V, G)."""
2130
+ assert V is None
2131
+ assert ortho_method is None
2132
+ assert velocity is None
2133
+ del V, ortho_method, velocity
2134
+
2135
+ Q = _balance_to_triu(oq, True)
2136
+ precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
2137
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
2138
+
2139
+ G = psgd_precond_grad(G, Q)
2140
+ terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
2141
+ matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
2142
+ _psgd_precond_update_(
2143
+ matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2144
+ )
2145
+ return G
1752
2146
 
1753
2147
 
1754
2148
  @decorator_knowngood
@@ -1785,6 +2179,34 @@ def rmsnorm_clip_(x, clip_at: float = 1.0):
1785
2179
  return _compilable_rmsnorm_clip_(x, clip_at)
1786
2180
 
1787
2181
 
2182
+ @decorator_knowngood
2183
+ def _compilable_global_rmsnorm_clip_(x, clip_at):
2184
+ x = list(map(promote, x))
2185
+ norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
2186
+ norm = norm**0.5
2187
+ norm = norm.clamp(min=clip_at)
2188
+ return torch._foreach_div(x, norm)
2189
+
2190
+
2191
+ @decorator_knowngood
2192
+ def _compilable_global_l2norm_clip_(x, clip_at):
2193
+ x = list(map(promote, x))
2194
+ norm = sum([x.square().sum() for x in x])
2195
+ norm = norm**0.5
2196
+ norm = norm.clamp(min=clip_at)
2197
+ return torch._foreach_div(x, norm)
2198
+
2199
+
2200
+ def global_rmsnorm_clip(x, clip_at: float = 1.0):
2201
+ x = list_guard(x)
2202
+ return _compilable_global_rmsnorm_clip_(x, clip_at)
2203
+
2204
+
2205
+ def global_l2norm_clip(x, clip_at: float = 1.0):
2206
+ x = list_guard(x)
2207
+ return _compilable_global_l2norm_clip_(x, clip_at)
2208
+
2209
+
1788
2210
  def rmsnorm_normalize_(x, clip_at: float = 1e-6):
1789
2211
  x = list_guard(x)
1790
2212
  return _compilable_rmsnorm_clip_(x, clip_at)
@@ -1862,6 +2284,17 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1862
2284
  _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
1863
2285
 
1864
2286
 
2287
+ @decorator_knowngood
2288
+ def _compilable_weight_decay_to_init_(p, init, weight_decay):
2289
+ _lerp(p, promote(init), 1 - weight_decay)
2290
+
2291
+
2292
+ def weight_decay_to_init_(p, init, weight_decay):
2293
+ p, init = list_guard(p, init)
2294
+ weight_decay = scalar_guard(weight_decay, p[0])
2295
+ _compilable_weight_decay_to_init_(p, init, weight_decay)
2296
+
2297
+
1865
2298
  @decorator_knowngood
1866
2299
  def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
1867
2300
  ema32 = _lerp(ema, p, ema_decay)
@@ -1920,35 +2353,25 @@ def triu_to_line(Q_list: List[Tensor]):
1920
2353
  if q.dim() < 2:
1921
2354
  out.append((None, q))
1922
2355
  else:
1923
- out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
2356
+ out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
1924
2357
  return out
1925
2358
 
1926
2359
 
1927
- def _triu_shape(numel):
1928
- n = int((2 * numel) ** 0.5)
1929
- assert n * (n + 1) == 2 * numel
1930
- return n, n
1931
-
1932
-
1933
- @decorator
1934
- def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
2360
+ @decorator_knowngood
2361
+ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
1935
2362
  new = []
1936
2363
  for shape, q in Q_list:
1937
2364
  if shape is not None:
1938
- shape = _triu_shape(q.numel())
1939
- x = torch.zeros(shape, device=q.device, dtype=q.dtype)
1940
- x[tuple(torch.triu_indices(*shape, device=q.device))] = q
1941
- q = x
2365
+ x, y = torch.triu_indices(*shape, device=q.device)
2366
+ q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
2367
+ q_mat[x, y] = q
2368
+ if symmetric_output:
2369
+ q_mat[y, x] = q
2370
+ q = q_mat
1942
2371
  new.append(q)
1943
2372
  return new
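To make the packed "line" representation concrete, here is a small round trip (an illustrative sketch; both helpers live in this file):

    import torch
    from heavyball.utils import line_to_triu, triu_to_line

    Q = torch.triu(torch.randn(3, 3))            # upper-triangular factor
    (shape, line), = triu_to_line([Q])           # shape == (3, 3); line holds the 6 packed entries
    (restored,) = line_to_triu([(shape, line)])  # upper-triangular matrix again
    # restored matches Q entry for entry; the strict lower triangle stays zero.
    # With symmetric_output=True the packed entries are also mirrored into the lower triangle.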
1944
2373
 
1945
2374
 
1946
- def update_triu_(q_state, materialised):
1947
- for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
1948
- assert shape0 == shape1
1949
- copy_stochastic_(q, m)
1950
-
1951
-
1952
2375
  _warned = set()
1953
2376
 
1954
2377
 
@@ -1971,52 +2394,118 @@ def psgd_should_update(
1971
2394
  return int(group[name]) > int(cumulative_prob)
1972
2395
 
1973
2396
 
2397
+ @functools.lru_cache(maxsize=None)
2398
+ def cached_precond_grad_expr(Q_dim, grad_dim):
2399
+ expr = [f"{c.upper()}{c}" if q_ == 2 else c for c, q_ in zip(einsum_base, Q_dim)]
2400
+ expr = ",".join(expr)
2401
+ grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
2402
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
2403
+ return f"{expr},{grad_expr}->{out_expr}"
2404
+
2405
+
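A concrete value makes the cached expression easier to read (derived from the code above, not from package documentation):

    from heavyball.utils import cached_precond_grad_expr

    expr = cached_precond_grad_expr((2, 2), 2)
    # expr == "Aa,Bb,ab->AB": each gradient mode is contracted with its cached matrix,
    # i.e. cached_q[0] @ grad @ cached_q[1].T in ordinary matrix notation.
    # A 1-D cached factor keeps its lowercase letter and acts as an elementwise
    # scale along that mode instead.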
1974
2406
  @decorator_knowngood
1975
2407
  def precond_grad_cached_(
1976
- expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
2408
+ ea: Tensor,
2409
+ cached_q: List[Tensor],
2410
+ caution: bool = False,
2411
+ grad: Optional[Tensor] = None,
2412
+ cast: bool = True,
1977
2413
  ):
1978
2414
  if caution:
1979
2415
  ea = _compilable_cautioning(grad, ea)
1980
2416
  md = min_dtype(list(cached_q) + [ea])
1981
2417
  args = [q.to(md) for q in cached_q]
1982
2418
  args = args + [ea.to(md)]
1983
- new = torch.einsum(expr, *args)
2419
+ expr = cached_precond_grad_expr(ndim_tuple(cached_q), grad.ndim)
2420
+ new = compiled_einsum(expr, *args)
1984
2421
  if cast:
1985
2422
  return new.to(ea.dtype)
1986
2423
  return new
1987
2424
 
1988
2425
 
2426
+ TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
2427
+
2428
+
1989
2429
  @decorator_knowngood
1990
- def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
1991
- precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
2430
+ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
2431
+ precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
1992
2432
  update_param_(param, precond, lr, decay, caution=False)
1993
2433
 
1994
2434
 
1995
- def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
2435
+ def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
1996
2436
  lr = scalar_guard(lr, param[0])
1997
- _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
2437
+ _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
2438
+
2439
+
2440
+ @functools.lru_cache(maxsize=None)
2441
+ def precond_grad_expr(Q_dim, grad_dim):
2442
+ expr = [
2443
+ f"{c2}{c.upper()},{c2}{c}" if q_ == 2 else f"{c},{c}" for c, c2, q_ in zip(einsum_base, einsum_base[13:], Q_dim)
2444
+ ]
2445
+ expr = ",".join(expr)
2446
+ grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
2447
+ out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
2448
+ return f"{expr},{grad_expr}->{out_expr}"
1998
2449
 
1999
2450
 
2000
2451
  @decorator_knowngood
2001
- def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
2452
+ def psgd_precond_grad(
2453
+ ea: Tensor,
2454
+ preconds: TriuOrLine,
2455
+ caution: bool = False,
2456
+ grad: Optional[Tensor] = None,
2457
+ store_triu_as_line: bool = False,
2458
+ symmetric_output: bool = False,
2459
+ ):
2002
2460
  if caution:
2003
2461
  ea = _compilable_cautioning(grad, ea)
2462
+ if store_triu_as_line:
2463
+ preconds = line_to_triu(preconds, symmetric_output)
2004
2464
  md = min_dtype(list(preconds) + [ea])
2005
2465
  args = [q.to(md) for q in preconds]
2006
- args = args + args + [ea.to(md)]
2007
- new = torch.einsum(expr, *args)
2466
+ expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
2467
+ new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
2008
2468
  return new.to(ea.dtype)
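For a 2-D tensor the einsum above reduces to ordinary matrix products, which may be the easiest way to read it (a sketch assuming two dense square factors; precond_grad_expr((2, 2), 2) evaluates to "nA,na,oB,ob,ab->AB"):

    import torch

    G = torch.randn(8, 4)
    Ql, Qr = torch.randn(8, 8), torch.randn(4, 4)
    # Each factor is fed in twice, so the contraction is (Ql^T Ql) @ G @ (Qr^T Qr):
    via_einsum = torch.einsum("nA,na,oB,ob,ab->AB", Ql, Ql, Qr, Qr, G)
    via_matmul = Ql.T @ Ql @ G @ Qr.T @ Qr
    assert torch.allclose(via_einsum, via_matmul, atol=1e-4)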
2009
2469
 
2010
2470
 
2011
2471
  @decorator_knowngood
2012
- def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
2013
- precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
2472
+ def _compilable_fused_psgd_precond_grad(
2473
+ ea: Tensor,
2474
+ param,
2475
+ lr,
2476
+ grad,
2477
+ decay,
2478
+ caution,
2479
+ preconds: TriuOrLine,
2480
+ store_triu_as_line: bool = False,
2481
+ symmetric_output: bool = False,
2482
+ ):
2483
+ precond = psgd_precond_grad(
2484
+ ea,
2485
+ preconds,
2486
+ caution=caution,
2487
+ grad=grad,
2488
+ store_triu_as_line=store_triu_as_line,
2489
+ symmetric_output=symmetric_output,
2490
+ )
2014
2491
  update_param_(param, precond, lr, decay, caution=False, grad=grad)
2015
2492
 
2016
2493
 
2017
- def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
2494
+ def fused_psgd_precond_grad(
2495
+ ea: Tensor,
2496
+ param,
2497
+ lr,
2498
+ grad,
2499
+ decay,
2500
+ caution,
2501
+ preconds: TriuOrLine,
2502
+ store_triu_as_line: bool = False,
2503
+ symmetric_output: bool = False,
2504
+ ):
2018
2505
  lr = scalar_guard(lr, param[0])
2019
- _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)
2506
+ _compilable_fused_psgd_precond_grad(
2507
+ ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
2508
+ )
2020
2509
 
2021
2510
 
2022
2511
  @decorator_knowngood
@@ -2068,7 +2557,15 @@ def caution(g, update):
2068
2557
  return _compilable_cautioning(g, update)
2069
2558
 
2070
2559
 
2071
- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
2560
+ def _inner_precond_update_prob_schedule(
2561
+ n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
2562
+ ):
2563
+ return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
2564
+
2565
+
2566
+ def precond_update_prob_schedule(
2567
+ max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
2568
+ ):
2072
2569
  """Anneal preconditioner update probability during beginning of training.
2073
2570
 
2074
2571
  PSGD benefits from more preconditioner updates at the beginning of training,
@@ -2079,11 +2576,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
2079
2576
  `min_prob` by ~4000 steps. Default settings work very well for most models and
2080
2577
  training regimes.
2081
2578
  """
2082
-
2083
- def _schedule(n):
2084
- return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
2085
-
2086
- return _schedule
2579
+ return functools.partial(
2580
+ _inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
2581
+ )
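Moving the closure into a module-level function wrapped in functools.partial keeps the same values while (presumably) making the returned schedule picklable alongside optimizer state. With the defaults, the schedule behaves as in this small worked example (values rounded):

    schedule = precond_update_prob_schedule()  # max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000
    schedule(500)   # 1.0    -- still inside the flat start
    schedule(2000)  # ~0.37  -- 1.0 * 0.999 ** 1000
    schedule(5000)  # 0.03   -- the decay has hit the min_prob floor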
2087
2582
 
2088
2583
 
2089
2584
  def merge_group(group, *tensors):
@@ -2217,3 +2712,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
2217
2712
  def disable_caution_scaling():
2218
2713
  global _compilable_cautioning
2219
2714
  _compilable_cautioning = _compilable_caution_no_scale
2715
+
2716
+
2717
+ @decorator_knowngood
2718
+ def sam_step(parameters, ball_size, adaptive: bool = True):
2719
+ old_params = []
2720
+ for p in parameters:
2721
+ old_params.append(p.detach().clone())
2722
+ grad = promote(p.grad)
2723
+ if adaptive:
2724
+ grad = grad * promote(p).square()
2725
+ stochastic_add_(p.data, grad, ball_size)
2726
+ p.grad.zero_()
2727
+ return old_params
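A rough SAM-style loop around sam_step (the toy objective and single parameter are stand-ins, not part of the package):

    import torch
    from heavyball.utils import sam_step

    w = torch.nn.Parameter(torch.randn(4))

    def loss():                    # toy objective, purely illustrative
        return (w ** 2).sum()

    loss().backward()              # gradients at the current point
    old = sam_step([w], 0.05)      # perturb w along its (parameter-scaled) gradient, zero .grad
    loss().backward()              # gradients at the perturbed point
    with torch.no_grad():
        w.copy_(old[0])            # restore the original weights before the optimizer step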