heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,5 +1,6 @@
  import collections
  import contextlib
+ import enum
  import functools
  import gc
  import inspect
@@ -16,13 +17,29 @@ import torch
  from torch import Tensor
  from torch._dynamo.exc import TorchDynamoException
  from torch.backends import cudnn, opt_einsum
+ from torch.nn import functional as F
  from torch.utils._pytree import tree_map

+
+ class ZerothPowerMode(enum.Enum):
+     newtonschulz = "newtonschulz"
+     legacy_newtonschulz = "legacy_newtonschulz"
+     qr = "qr"
+     svd = "svd"
+     legacy_svd = "legacy_svd"
+
+
+ class OrthoScaleMode(enum.Enum):
+     none = "none"
+     scale = "scale"
+     graft = "graft"
+
+
  compile_mode = "max-autotune-no-cudagraphs"
  dynamic = False
  compile_mode_recommended_to_none = None
  zeroth_power_mode = "newtonschulz"
- precise_zeroth_power_mode = "qr"  # or svd
+ precise_zeroth_power_mode = "qr"
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
  _cudnn_double_backward_pattern = re.compile(
      r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -240,14 +257,14 @@ def eps_sqrt(item, eps):

  @decorator_knowngood
  def _compilable_exp_avg_sq_(
-     state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
+     state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
  ):
      g32 = promote(grad)
      s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)

      denom = [eps_sqrt(d, eps) for d in s32]

-     if out[0] is None:
+     if out is None or out[0] is None:
          return denom

      copy_stochastic_list_(out, denom)
@@ -316,8 +333,8 @@ def adaptive_gradient_clipping_(
  def is_compiling():
      try:
          return torch.compiler.is_compiling()
-     except TorchDynamoException:
-         return True
+     except (TorchDynamoException, AttributeError):
+         return False


  def set_(dst: Tensor, src: Tensor):
@@ -366,12 +383,45 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
366
383
  )
367
384
 
368
385
 
369
- @decorator
386
+ @decorator_knowngood
370
387
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
388
+ assert (
389
+ G.ndim >= 2
390
+ ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
391
+ assert steps == 5
392
+ G = G.clone()
393
+ X = G if G.dtype == torch.float64 else stochastic_round_(G)
394
+ if G.size(-2) > G.size(-1):
395
+ X = X.mT
396
+
397
+ # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
398
+ stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
399
+ # Perform the NS iterations
400
+ for a, b, c in [
401
+ (4.0848, -6.8946, 2.9270),
402
+ (3.9505, -6.3029, 2.6377),
403
+ (3.7418, -5.5913, 2.3037),
404
+ (2.8769, -3.1427, 1.2046),
405
+ (2.8366, -3.0525, 1.2012),
406
+ ]:
407
+ A = X @ X.mT
408
+ B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
409
+ X = a * X + B @ X
410
+
411
+ if G.size(-2) > G.size(-1):
412
+ X = X.mT
413
+ return X.to(G.dtype)
414
+
415
+
416
+ @decorator_knowngood
417
+ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
371
418
  assert len(G.shape) == 2
372
419
  a, b, c = (3.4445, -4.7750, 2.0315)
373
- X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
374
- X /= X.norm() + eps # ensure top singular value <= 1
420
+ G = G.clone()
421
+ X = G if G.dtype == torch.float64 else stochastic_round_(G)
422
+
423
+ # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
424
+ stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
375
425
  if G.size(0) > G.size(1):
376
426
  X = X.T
377
427
  for _ in range(steps):
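For reference, the reworked `zeropower_via_newtonschulz5` runs a fixed five-step quintic Newton-Schulz iteration with a per-step coefficient schedule. A hedged fp32 sketch of the same iteration, without heavyball's `stochastic_round_`/`stochastic_divide_with_eps_` helpers, looks like this (the schedule is copied from the hunk above; Muon-style coefficients trade exact orthogonality for speed):

```python
import torch


def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # fp32 reference of the quintic Newton-Schulz orthogonalization used by Muon-style optimizers.
    X = G.float().clone()
    if X.size(-2) > X.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # top singular value <= 1
    for a, b, c in [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]:
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


G = torch.randn(64, 32)
Q = newton_schulz_5(G)
# Columns are only roughly orthonormal: the coefficients approximately equalize singular values.
print((Q.T @ Q - torch.eye(32)).norm())
```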
@@ -435,21 +485,30 @@ def _compilable_grafting(magnitude, direction):


  @decorator_knowngood
- def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
-     if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
+ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str | OrthoScaleMode):
+     if not isinstance(mode, ZerothPowerMode):
+         mode = ZerothPowerMode(mode)
+     if not isinstance(scale_mode, OrthoScaleMode):
+         scale_mode = OrthoScaleMode(scale_mode)
+     if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
          y = zeropower_via_newtonschulz5(x, 5)
-     elif mode == "qr":
+     elif mode == ZerothPowerMode.legacy_newtonschulz:
+         y = legacy_zeropower_via_newtonschulz5(x, 5)
+     elif mode == ZerothPowerMode.qr:
          y = torch.linalg.qr(promote(x)).Q
-     elif mode == "svd":
-         u, _s, v = torch.linalg.svd(promote(x))
-         y = u @ v.T
+     elif mode == ZerothPowerMode.svd:
+         u, _s, vt = torch.linalg.svd(promote(x))
+         y = u @ vt
+     elif mode == ZerothPowerMode.legacy_svd:
+         u, _s, vt = torch.linalg.svd(promote(x))
+         y = u @ vt.T
      else:
          raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
-     if scale_mode == "none":
+     if scale_mode == OrthoScaleMode.none:
          pass
-     elif scale_mode == "scale":
-         y *= max(1, x.size(0) / x.size(1)) ** 0.5
-     elif scale_mode == "graft":
+     elif scale_mode == OrthoScaleMode.scale:
+         y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
+     elif scale_mode == OrthoScaleMode.graft:
          y = _compilable_grafting(x, y)
      else:
          raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
@@ -556,10 +615,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
      except torch.OutOfMemoryError:
          if m.device.type == "cpu":
              raise
-         else:
+         if torch.cuda.is_available():
+             torch.cuda.synchronize(m.device)
+         clean()
+         m = m.cpu()
+     except RuntimeError as e:
+         if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
+             torch.cuda.synchronize(m.device)
+             clean()
              m = m.cpu()
-     except RuntimeError:  # failed to compute eigenvalues
-         if m.dtype != torch.double:
+         elif m.dtype != torch.double:
              m = m.double()
          elif eps < max_eps:
              eps = eps ** (2 / 3)
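The new handlers synchronize, free cached memory, and retry the decomposition on CPU after CUDA out-of-memory or illegal-access failures. A hedged standalone sketch of that retry pattern, using `gc.collect()`/`torch.cuda.empty_cache()` in place of heavyball's `clean()` helper:

```python
import gc

import torch


def eigh_with_cpu_fallback(m: torch.Tensor):
    # One-shot retry: on a CUDA OOM or other CUDA runtime error, synchronize,
    # release cached blocks, and redo the decomposition on CPU.
    try:
        return torch.linalg.eigh(m)
    except (torch.OutOfMemoryError, RuntimeError):
        if m.device.type == "cpu":
            raise
        if torch.cuda.is_available():
            torch.cuda.synchronize(m.device)
            torch.cuda.empty_cache()
        gc.collect()
        return torch.linalg.eigh(m.cpu())
```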
@@ -658,7 +723,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
          copy_stochastic_(x_, x32 + y32 * alpha)


- def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
+ def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
      x, y = broadcastable_list_guard(x, y)
      alpha = scalar_guard(alpha, x[0])
      _compilable_stochastic_add_(x, y, alpha)
@@ -672,7 +737,9 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
          copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)


- def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
+ def stochastic_add_divide_(
+     x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
+ ):
      x, y = broadcastable_list_guard(x, y)
      alpha, divisor = scalar_guard(alpha, divisor, x[0])
      _compilable_stochastic_add_divide_(x, y, alpha, divisor)
@@ -686,11 +753,25 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
          copy_stochastic_(x_, x32 * y32)


- def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+ def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
      x, y = broadcastable_list_guard(x, y)
      _compilable_stochastic_multiply_(x, y)


+ @decorator_knowngood
+ def _compilable_stochastic_divide_with_eps_(x: List[Tensor], y: List[Tensor], eps: Tensor):
+     for x_, y_ in zip(x, y):
+         x32 = promote(x_)
+         y32 = promote(y_)
+         copy_stochastic_(x_, x32 / (y32 + eps))
+
+
+ def stochastic_divide_with_eps_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, eps: float):
+     x, y = broadcastable_list_guard(x, y)
+     eps = scalar_guard(eps, y[0])
+     _compilable_stochastic_divide_with_eps_(x, y, eps)
+
+
  @decorator
  def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
      """
@@ -832,6 +913,10 @@ class ExactHVPFailed(ValueError):
  use_default = object()


+ def _tensor_key(x: Tensor):
+     return x.data_ptr(), x.numel(), x.dtype, x.device
+
+
  class StatefulOptimizer(torch.optim.Optimizer):
      """
      finite_differences saves memory, but needs more compute. (Alternative is true HVP)
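`mapping_inverse` is now keyed by `_tensor_key` tuples instead of tensor objects. Tensors hash by identity, so a view that is re-created (as `split_p_and_g_in_group` now does on every call) would never hit an object-keyed dict; keying by `(data_ptr, numel, dtype, device)` makes equivalent views collide on purpose. A small demonstration, assuming nothing beyond plain PyTorch:

```python
import torch

base = torch.zeros(4, 4)
view_a = base.view(16)
view_b = base.view(16)  # same storage, different Python object

d = {view_a: "state"}
print(view_b in d)  # False: dict lookup hashes tensors by object identity


def tensor_key(t: torch.Tensor):
    # mirrors heavyball's _tensor_key
    return t.data_ptr(), t.numel(), t.dtype, t.device


d2 = {tensor_key(view_a): "state"}
print(tensor_key(view_b) in d2)  # True: re-created views map to the same entry
```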
@@ -874,7 +959,6 @@ class StatefulOptimizer(torch.optim.Optimizer):

          self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
          self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
-         self._init_mapping()

      def _store_stats(self, state_dict: dict[str, any]):
          state_dict["heavyball"] = {
@@ -905,7 +989,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
      def state_(self, arg: Tensor, fail: bool = True):
          if not fail and arg not in self.mapping:
              return {}
-         state_param, index = self.mapping_inverse[arg]
+         if _tensor_key(arg) not in self.mapping_inverse:
+             self._init_mapping()
+         state_param, index = self.mapping_inverse[_tensor_key(arg)]
          if state_param not in self.state:
              self.state[state_param] = collections.defaultdict(dict)
          return self.state[state_param][index]
@@ -928,7 +1014,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
              if p not in self.mapping:
                  self.mapping[p] = p_views = merge_group(group, p)
                  for i, pv in enumerate(p_views):
-                     self.mapping_inverse[pv] = (p, i)
+                     self.mapping_inverse[_tensor_key(pv)] = (p, i)

      def split_p_and_g_in_group(
          self,
@@ -949,12 +1035,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
                  yield p, grad
                  continue

-             if p in self.mapping:
-                 p_views = self.mapping[p]
-             else:
-                 self.mapping[p] = p_views = merge_group(group, p)
-                 for i, pv in enumerate(p_views):
-                     self.mapping_inverse[pv] = (p, i)
+             self.mapping[p] = p_views = merge_group(group, p)
+             for i, pv in enumerate(p_views):
+                 self.mapping_inverse[_tensor_key(pv)] = (p, i)

              vector = getattr(p, "vector", None)
              hessian_vector = getattr(p, "hessian_vector", None)
@@ -1199,17 +1282,53 @@ def _compilable_adam_(


  def adam_(
+     exp_avg: List[Tensor] | Tensor,
+     exp_avg_sq: List[Tensor] | Tensor,
+     grad: List[Tensor] | Tensor,
+     beta1: float,
+     beta2: float,
+     step: int,
+     eps: float = 1e-8,
+ ) -> List[Tensor]:
+     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+     return grad
+
+
+ @decorator_knowngood
+ def _compilable_unscaled_adam_(
      exp_avg: List[Tensor],
      exp_avg_sq: List[Tensor],
      grad: List[Tensor],
+     beta1: Tensor,
+     beta2: Tensor,
+     step: Tensor,
+     eps: Tensor,
+ ):
+     beta1 = beta_debias(beta1, step)
+     beta2 = beta_debias(beta2, step)
+
+     g32 = list(map(promote, grad))
+     denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+     g32 = torch._foreach_div(g32, denom)
+     exp_avg32 = _lerp(exp_avg, g32, beta1)
+     u32 = torch._foreach_mul(exp_avg32, denom)
+     copy_stochastic_list_(grad, u32)
+
+
+ def unscaled_adam_(
+     exp_avg: List[Tensor] | Tensor,
+     exp_avg_sq: List[Tensor] | Tensor,
+     grad: List[Tensor] | Tensor,
      beta1: float,
      beta2: float,
      step: int,
      eps: float = 1e-8,
- ):
+ ) -> List[Tensor]:
      exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
      beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
-     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+     _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
      return grad


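`unscaled_adam_` whitens the gradient with the second-moment denominator before the momentum average and multiplies the denominator back afterwards. The following fp32 single-tensor sketch is my reading of that sequence, with an assumed textbook form for `beta_debias` and without heavyball's foreach/stochastic-rounding kernels:

```python
import torch


def unscaled_adam_step(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps=1e-8):
    # fp32 sketch of _compilable_unscaled_adam_ for a single tensor:
    # debias betas, update the second moment, whiten the gradient, take the
    # momentum average in whitened space, then re-apply the denominator.
    beta1 = 1 - (1 - beta1) / (1 - beta1**step)  # assumed beta_debias form
    beta2 = 1 - (1 - beta2) / (1 - beta2**step)
    exp_avg_sq.lerp_(grad * grad, 1 - beta2)
    denom = exp_avg_sq.sqrt().clamp_min(eps)
    exp_avg.lerp_(grad / denom, 1 - beta1)
    return exp_avg * denom


m, v = torch.zeros(3), torch.zeros(3)
print(unscaled_adam_step(m, v, torch.randn(3), 0.9, 0.99, step=1))
```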
@@ -1253,7 +1372,7 @@ def fused_adam_(
      caution: bool,
  ):
      y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
-     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
+     beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
      _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


@@ -1332,7 +1451,7 @@ def fused_laprop_(
      eps: float = 1e-8,
  ):
      exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-     beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+     beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
      _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)


@@ -1351,7 +1470,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,

  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
      exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+     beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
      _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)


@@ -1381,11 +1500,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):


  @decorator_knowngood
- def stochastic_round_(ref: Tensor, source: Tensor):
-     if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
-         return source
-     if ref.dtype != torch.bfloat16:
-         return source.to(ref.dtype)
+ def stochastic_round_(ref: Tensor, source: Tensor | None = None):
+     if source is not None:
+         if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+             return source
+         if ref.dtype != torch.bfloat16:
+             return source.to(ref.dtype)
+     else:
+         source = ref
+     source = source.float()
      result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
      result.add_(source.view(dtype=torch.int32))
      result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
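`stochastic_round_` can now be called with a single tensor, in which case it stochastically rounds it to bfloat16 using the int32 bit trick visible above. A standalone sketch of the same trick, plus a quick check that the rounding is unbiased on average:

```python
import torch


def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Add random bits below the bf16 mantissa cutoff, then truncate: this rounds
    # fp32 -> bf16 up or down with probability proportional to the remainder.
    x = x.float()
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(torch.int32) + noise
    bits &= -65536  # 0xFFFF0000: keep sign, exponent, and the top 7 mantissa bits
    return bits.view(torch.float32).bfloat16()


x = torch.full((100_000,), 1.0 + 1e-4)
r = stochastic_round_to_bf16(x)
print(r.float().unique())        # the two neighbouring bf16 values around 1.0001
print(r.float().mean().item())   # close to 1.0001 on average (unbiased)
```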
@@ -1908,7 +2031,9 @@ def ndim_tuple(Q: list[Tensor]) -> tuple:
1908
2031
  return tuple(q.ndim for q in Q)
1909
2032
 
1910
2033
 
1911
- def psgd_calc_A_and_conjB(G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
2034
+ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "vector") == randn during hvp/whitening
2035
+ if conjB is None:
2036
+ conjB = torch.randn_like(G)
1912
2037
  exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
1913
2038
  A = casted_einsum(exprA, *Q, G)
1914
2039
  solve = torch.compiler.disable(torch.linalg.solve_triangular)
@@ -1940,24 +2065,35 @@ def max_singular_value_exact(A, use_lobpcg: bool = False):
              eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
              return eigval[0].sqrt()
          else:
-             return torch.linalg.svd(A, driver="gesvdj")[1].max()  # == linalg.matrix_norm(A, ord=2)
-     except torch.linalg.LinAlgError:
-         return torch.zeros((), device=A.device, dtype=A.dtype)
+             return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype)  # == linalg.matrix_norm(A, ord=2)
+     except (torch.linalg.LinAlgError, RuntimeError):
+         return max_singular_value_power_iter(promote(A), iterations=2)


  @decorator_knowngood
- def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
+ def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
      """
      Rayleigh quotient of row with the largest norm + optional power iterations
      """
-     x_norm, max_idx = A.norm(dim=1).max(dim=0)
-     x = A.index_select(0, max_idx).flatten().contiguous()
-     A = A / x_norm
-     x = x / x_norm
-     for _ in range(iterations):
-         x = A.T.mv(A.mv(x))  # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
-         x = x / x.norm()
-     return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm
+     x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+     x_norm = promote(x_norm)
+
+     def _inner():
+         A = A_outer
+         x = A.index_select(0, max_idx).flatten().contiguous()
+         A = stochastic_round_(A / x_norm)
+         x = x / x_norm
+
+         def _mv(x):
+             return promote(A.T.mv(A.mv(stochastic_round_(x))))
+
+         for _ in range(iterations):
+             # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
+             x = F.normalize(_mv(x), dim=0)
+         out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+         return out.squeeze().clone()
+
+     return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())


  @decorator_knowngood
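The reworked `max_singular_value_power_iter` seeds the iteration with the largest-norm row and evaluates a Rayleigh quotient of AᵀA. A hedged fp32 sketch of that estimator, checked against `torch.linalg.matrix_norm`:

```python
import torch


def spectral_norm_power_iter(A: torch.Tensor, iterations: int = 16) -> torch.Tensor:
    # Start from the row with the largest norm (a decent first guess for the top
    # right-singular direction), then iterate x <- A^T A x and take a Rayleigh quotient.
    x = A[A.norm(dim=1).argmax()].clone()
    x = x / x.norm()
    for _ in range(iterations):
        x = A.T @ (A @ x)
        x = x / x.norm()
    return (x @ (A.T @ (A @ x))).sqrt()


A = torch.randn(256, 128)
print(spectral_norm_power_iter(A).item())
print(torch.linalg.matrix_norm(A, ord=2).item())  # reference value
```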
@@ -1974,33 +2110,81 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
      return sketch_norm * max_abs


+ def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+     if A.ndim <= 2:
+         return max_singular_value(A, max_svd, use_cholesky, power_iter)
+
+     base = einsum_base[: A.ndim]
+     A16 = stochastic_round_(A)
+     squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
+     svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
+     svds = torch.stack(svds)
+     return svds.max().sqrt().to(A.dtype)  # sqrt because we took the SVD of a squared matrix
+
+
  @decorator_knowngood
- def max_singular_value(
-     A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
- ) -> Tensor:
+ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+     if A.ndim < 2:
+         return A.abs().max()
+     if A.ndim > 2:
+         raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
      if min(A.shape) <= max_svd:
          return max_singular_value_exact(A)  # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
      if use_cholesky or power_iter < 0:
-         return max_singular_value_cholesky(A, max_abs)
+         return max_singular_value_cholesky(A)
      return max_singular_value_power_iter(A, None, iterations=power_iter)


  @decorator_knowngood
- def _psgd_default_preconditioner_grad(
-     terms: List[Tuple[Tensor, Tensor]],
-     Q: List[Tensor],
- ) -> List[Tensor]:
-     out = []
-     for q, (x, y) in zip(Q, terms):
-         x = promote(x)
-         y = promote(y)
-         update = x - y
-         if q.ndim < 2:
-             update = q * update
-         else:
-             update = (q @ update).triu()
-         out.append(update)
-     return out
+ def clamped_max_singular_value(
+     A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
+ ) -> Tensor:
+     norm = A.norm()  # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
+     out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
+     return out.clamp(min=min)
+
+
+ @decorator_knowngood
+ def min_singular_value(
+     A: Tensor,
+     power_iter: int = 5,
+     safety: float = 1.05,
+     max_svd: int = 32,
+ ):
+     if A.ndim < 2:
+         return A.abs().min()
+
+     n = A.size(0)
+     if n <= max_svd:
+         try:
+             eigs = torch.linalg.eigvalsh(promote(A))
+             return eigs.min().to(A.dtype)
+         except torch.linalg.LinAlgError:
+             pass
+
+     lambda_max_hat = max_singular_value(A, power_iter=power_iter)
+     lambda_upper = lambda_max_hat * safety
+
+     row_norms = A.norm(dim=1)
+     norm, idx = row_norms.min(dim=0)
+     v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))
+
+     v = v / promote(v.norm())
+     for _ in range(power_iter):
+         v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
+         v = v / promote(v.norm())
+     mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))
+
+     lambda_min_hat = lambda_upper - mu_hat
+
+     def _approx():
+         mu = A.trace() / n
+         sigma_square = A.square().sum() / n - mu**2
+         return mu - (sigma_square / (n - 1)).sqrt()
+
+     return cond(
+         (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
+     ).squeeze()


  @decorator_knowngood
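`min_singular_value` estimates the smallest eigenvalue of a symmetric PSD matrix by power iteration on the shifted matrix λ_upper·I − A, whose dominant eigenvalue is λ_upper − λ_min(A). A standalone check of that idea against `torch.linalg.eigvalsh` (convergence can be slow when the small eigenvalues are clustered):

```python
import torch


def min_eig_via_shift(A: torch.Tensor, power_iter: int = 50, safety: float = 1.05) -> torch.Tensor:
    # A is assumed symmetric PSD. The largest eigenvalue of (lambda_upper*I - A) is
    # lambda_upper - lambda_min(A), so a plain power iteration on the shifted matrix
    # recovers the smallest eigenvalue of A.
    lambda_upper = torch.linalg.matrix_norm(A, ord=2) * safety
    v = torch.randn(A.size(0), dtype=A.dtype)
    v = v / v.norm()
    for _ in range(power_iter):
        v = lambda_upper * v - A @ v
        v = v / v.norm()
    mu = v @ (lambda_upper * v - A @ v)
    return lambda_upper - mu


M = torch.randn(64, 32, dtype=torch.float64)
A = M.T @ M  # symmetric PSD
print(min_eig_via_shift(A).item())
print(torch.linalg.eigvalsh(A).min().item())  # reference
```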
@@ -2027,6 +2211,107 @@ def calcG_expr(q_dim, g_dim):
      return exprs


+ def eye_like(x: Tensor):
+     if x.ndim < 2:
+         return torch.ones_like(x)
+     assert x.ndim == 2
+     assert x.size(0) == x.size(1)
+     return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
+
+
+ @decorator_knowngood
+ def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
+     """
+     Idea:
+         G should become a zeroth power, so all Qs together should approximate G's inverse.
+         Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R.
+         Optimize LGR being a zeroth power using `MSE((LGR)(LGR).T, I) + MSE((LGR).T(LGR), I)`,
+         then backprop to L/R jointly.
+         This function computes the gradients for L/R, with an outer optimizer layer handling the rest.
+
+     `psgd_precond_grad` computes LGR for the general (n-dimensional) case.
+     `exprG` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case.
+     Args:
+         G: Gradient that should be orthogonalized
+         Q: List of preconditioner tensors.
+
+     Returns:
+         - List of gradients with respect to Q (d_Q).
+         - The preconditioned gradient P, cast back to G's dtype.
+     """
+     exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+
+     G16 = stochastic_round_(G)
+     Q16 = [stochastic_round_(q) for q in Q]
+     P = psgd_precond_grad(G16, Q16)  # Q₀GQ₁
+
+     d_P = torch.zeros_like(G)
+     base = einsum_base[: G.ndim]
+     for i, exprG in enumerate(exprGs):
+         pp = compiled_einsum(exprG, P, P)
+         error = pp - eye_like(pp)
+         dim = einsum_base[i]
+         if pp.ndim == 2:
+             new = dim.upper()
+             prec = f"{new}{dim}"
+         else:
+             new = dim
+             prec = dim
+         d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)
+
+     d_P = stochastic_round_(d_P)  # accumulate in fp32 and round at the end
+     grads = []
+     for i, exprG in enumerate(exprGs):
+         new_q = Q16[:]
+         new_q[i] = eye_like(new_q[i])
+         pq = psgd_precond_grad(G16, new_q)
+         grad = compiled_einsum(exprG, pq, d_P)
+         if grad.ndim == 2:
+             grad = (grad + grad.T) / 2
+         grads.append(grad)
+
+     return grads, P.to(G.dtype)
+
+
+ def _inverse_initial_guess(gg):
+     n = gg.shape[0]
+
+     sigma_max = promote(gg.norm())
+
+     trace_gg = promote(torch.trace(gg))
+     sigma_min_approx = trace_gg / (n * sigma_max)
+
+     return sigma_max, sigma_min_approx
+
+
+ @decorator_knowngood
+ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
+     k = torch.arange(degree, dtype=torch.float64, device=device)
+     rotation = (2 * k + 1) * math.pi / (2 * degree)
+     f = (rotation.cos() + 1 + eps) ** -0.5
+     rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()
+     coeff0 = f.sum() / degree
+     coeffs = f @ rotation * 2 / degree
+     return coeff0.float(), coeffs.float()
+
+
+ @decorator_knowngood
+ def _psgd_default_preconditioner_grad(
+     terms: List[Tuple[Tensor, Tensor]],
+     Q: List[Tensor],
+ ) -> List[Tensor]:
+     out = []
+     for q, (x, y) in zip(Q, terms):
+         x = promote(x)
+         y = promote(y)
+         update = x - y
+         if q.ndim < 2:
+             update = q * update
+         else:
+             update = (q @ update).triu()
+         out.append(update)
+     return out
+
+
  @decorator
  def psgd_update_precond(
      G: Tensor,
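`_chebychef_coeff` evaluates the standard Chebyshev-interpolation formula at cosine nodes for f(cos θ) = (cos θ + 1 + ε)^(−1/2). The sketch below recomputes the same coefficients and evaluates the truncated series against the target function; agreement is good away from the near-singularity at x ≈ −1:

```python
import math

import torch


def chebyshev_coeff(degree: int, eps: float = 1e-8):
    # Same node/coefficient formula as the diff: sample f at Chebyshev nodes and
    # project onto cos(k * theta).
    k = torch.arange(degree, dtype=torch.float64)
    theta = (2 * k + 1) * math.pi / (2 * degree)
    f = (theta.cos() + 1 + eps) ** -0.5
    c0 = f.sum() / degree
    ck = (f @ (theta.view(-1, 1) * k[1:].view(1, -1)).cos()) * 2 / degree
    return c0, ck


def eval_series(x: torch.Tensor, c0, ck):
    # T_k(x) = cos(k * arccos(x)) for x in [-1, 1]
    k = torch.arange(1, ck.numel() + 1, dtype=x.dtype)
    return c0 + (ck * (k * x.arccos().unsqueeze(-1)).cos()).sum(-1)


c0, ck = chebyshev_coeff(16)
x = torch.linspace(-0.5, 0.99, 5, dtype=torch.float64)
print(eval_series(x, c0, ck))
print((x + 1 + 1e-8) ** -0.5)  # target function
```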
@@ -2056,6 +2341,102 @@ def psgd_update_precond(
      return None


+ @decorator_knowngood
+ def bf16_matmul(x: Tensor, y: Tensor):
+     return (promote(x) @ promote(y)).to(x.dtype)
+
+
+ def if_iscompiling(fn):
+     base = getattr(torch, fn.__name__, None)
+
+     def _fn(*args):
+         if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
+             return base(*args)
+         return fn(*args)
+
+     return _fn
+
+
+ @if_iscompiling
+ def while_loop(cond, body, state):
+     """
+     Dispatches to torch.while_loop if we're compiling. Otherwise, falls back to a naive + slow baseline.
+     Useful for debugging.
+     """
+     while cond(*state).item():
+         state = body(*state)
+     return state
+
+
+ @if_iscompiling
+ def cond(cond, true_fn, false_fn):
+     """
+     Dispatches to torch.cond if we're compiling. Otherwise, falls back to a naive + slow baseline.
+     Useful for debugging.
+     """
+
+     if cond.item():
+         return true_fn()
+     return false_fn()
+
+
+ @decorator_knowngood
+ def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
+     """
+     Return w such that H = I - 2 w w^T is orthogonal and H e1 = v (v unit).
+     Applying from the right: G @ H = G - 2 (G @ w) w^T.
+     If v is (numerically) e1, returns w=0 and H=I.
+     """
+     v = v / v.norm().clamp(min=eps)
+     e1 = torch.zeros_like(v)
+     e1[0] = 1.0
+     w = e1 - v
+     return w / w.norm().clamp(min=eps)
+
+
+ @decorator_knowngood
+ def eigvecs_product_rank1(
+     G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
+ ) -> Tuple[Tensor, Tensor]:
+     """
+     Compute Y = G @ V where V is an eigenvector matrix for P = λ I + σ v v^T,
+     using the Householder reflector with first column v. Never materializes V.
+
+     Args:
+         G: shape (..., d) — gradient row(s) you want to rotate into the eigenbasis.
+         v: shape (d,) — current unit direction (top eigenvector of P).
+         w: optional Householder vector w; pass to reuse across calls.
+
+     Returns:
+         (Y, w) where:
+             Y has shape (..., d) and equals G @ eigenvectors(P),
+             w is the Householder vector you can cache & reuse.
+     """
+     if w is None:
+         w = _householder_vec_e1_to_v(v, eps)
+     Y = G - 2.0 * compiled_einsum("...i,i,j->...j", G, w, w)
+     return Y, w
+
+
+ @decorator_knowngood
+ def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
+     """
+     One Oja step to track the top eigendirection of the gradient covariance.
+     v <- v + lr * ((g^T v) g - (g^T v)^2 v); then renormalize.
+     """
+     gv = g @ v
+     v = v + lr * (gv * g - (gv * gv) * v)
+     return v / v.norm().clamp(min=eps)
+
+
+ def cond_n(cond_val: Tensor, *fns):
+     fns = list(fns)
+     fn = fns.pop(0)
+     if not fns:
+         return fn
+     return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
+
+
  @decorator_knowngood
  def _psgd_precond_update_(
      matmuled: List[Optional[Tensor]],
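`_householder_vec_e1_to_v` returns the unit vector w of the reflector H = I − 2wwᵀ that maps e₁ to v, which lets `eigvecs_product_rank1` rotate gradients into the eigenbasis of λI + σvvᵀ without materializing the eigenvector matrix. A standalone numerical check:

```python
import torch

d = 8
v = torch.randn(d, dtype=torch.float64)
v = v / v.norm()

e1 = torch.zeros(d, dtype=torch.float64)
e1[0] = 1.0
w = e1 - v
w = w / w.norm()

H = torch.eye(d, dtype=torch.float64) - 2.0 * torch.outer(w, w)
print(torch.allclose(H @ e1, v))                                    # H maps e1 to v
print(torch.allclose(H @ H.T, torch.eye(d, dtype=torch.float64)))   # H is orthogonal

# Applying from the right without forming H, as eigvecs_product_rank1 does:
G = torch.randn(3, d, dtype=torch.float64)
Y = G - 2.0 * torch.einsum("bi,i,j->bj", G, w, w)
print(torch.allclose(Y, G @ H))
```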
@@ -2074,7 +2455,7 @@ def _psgd_precond_update_(
          if update.ndim < 2:
              lb = update.norm(float("inf"))
          else:
-             lb = max_singular_value(update, None, power_iter=power_iter)
+             lb = max_singular_value(update, power_iter=power_iter)
          update = promote(update)
          if store_triu_as_line:
              update = triu_to_line([update])[0][1]
@@ -2146,70 +2527,83 @@ def inverse_free_psgd_update_precond(


  @decorator_knowngood
- def _compilable_l2_clip_(x, clip_at):
-     ref = x
-     x = list(map(promote, x))
-     norm = torch._foreach_norm(x)
-     torch._foreach_maximum_(norm, clip_at)
-     out = torch._foreach_div(x, norm)
-     return stochastic_round_list_(ref, out)
+ def _clip(x, norm, clip_at, eps=1e-8):
+     x32 = promote(x)
+     # (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
+     norm = clip_at / norm.clamp(min=max(clip_at, eps))
+     x32 = x32 * norm
+     copy_stochastic_(x, x32)
+
+
+ @decorator_knowngood
+ def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
+     for x in xs:
+         _clip(x, promote(x).norm(), clip_at, eps)


  def l2_normalization_(x, clip_at: float = 1e-8):
      x = list_guard(x)
-     return _compilable_l2_clip_(x, clip_at)
+     _compilable_l2_clip_(x, clip_at)
+     return x


  def l2_clip_(x, clip_at: float = 1.0):
      x = list_guard(x)
-     return _compilable_l2_clip_(x, clip_at)
+     _compilable_l2_clip_(x, clip_at)
+     return x


  @decorator_knowngood
- def _compilable_rmsnorm_clip_(x, clip_at):
-     x = list(map(promote, x))
-     norm = torch._foreach_norm(x)
-     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-     torch._foreach_maximum_(norm, clip_at)
-     return torch._foreach_div(x, norm)
+ def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
+     for x in xs:
+         _clip(x, promote(x).square().mean().sqrt(), clip_at, eps)


  def rmsnorm_clip_(x, clip_at: float = 1.0):
      x = list_guard(x)
-     return _compilable_rmsnorm_clip_(x, clip_at)
-
-
- @decorator_knowngood
- def _compilable_global_rmsnorm_clip_(x, clip_at):
-     x = list(map(promote, x))
-     norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
-     norm = norm**0.5
-     norm = norm.clamp(min=clip_at)
-     return torch._foreach_div(x, norm)
+     _compilable_rmsnorm_clip_(x, clip_at)
+     return x


  @decorator_knowngood
- def _compilable_global_l2norm_clip_(x, clip_at):
-     x = list(map(promote, x))
-     norm = sum([x.square().sum() for x in x])
-     norm = norm**0.5
-     norm = norm.clamp(min=clip_at)
-     return torch._foreach_div(x, norm)
+ def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
+     norm = 0
+     numel = sum([i.numel() for i in x])
+     for i in x:
+         norm += promote(i).square().sum()
+     norm = (norm / numel) ** 0.5
+     scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+     stochastic_multiply_(x, scalar)


  def global_rmsnorm_clip(x, clip_at: float = 1.0):
      x = list_guard(x)
-     return _compilable_global_rmsnorm_clip_(x, clip_at)
+     clip_at = scalar_guard(clip_at, x[0])
+     _compilable_global_rmsnorm_clip_(x, clip_at)
+     return x
+
+
+ @decorator_knowngood
+ def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
+     norm = 0
+     for i in x:
+         norm += promote(i).square().sum()
+     norm = norm**0.5
+     scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+     stochastic_multiply_(x, scalar)


  def global_l2norm_clip(x, clip_at: float = 1.0):
      x = list_guard(x)
-     return _compilable_global_rmsnorm_clip_(x, clip_at)
+     clip_at = scalar_guard(clip_at, x[0])
+     _compilable_global_l2norm_clip_(x, clip_at)
+     return x


  def rmsnorm_normalize_(x, clip_at: float = 1e-6):
      x = list_guard(x)
-     return _compilable_rmsnorm_clip_(x, clip_at)
+     _compilable_rmsnorm_clip_(x, clip_at)
+     return x


  @decorator_knowngood
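All clip helpers now multiply by `clip_at / max(norm, clip_at, eps)` instead of dividing by a clamped norm, so updates below the threshold pass through unchanged. A hedged standalone sketch of the global L2 variant (names here are mine, not heavyball's):

```python
import torch


def global_l2_clip(tensors, clip_at: float = 1.0, eps: float = 1e-8):
    # Scale every tensor by clip_at / max(global_norm, clip_at, eps):
    # a no-op when the combined norm is already below clip_at.
    norm = torch.sqrt(sum(t.float().square().sum() for t in tensors))
    scale = clip_at / norm.clamp(min=max(clip_at, eps))
    for t in tensors:
        t.mul_(scale)
    return tensors


grads = [torch.randn(10, 10) * 5, torch.randn(100)]
global_l2_clip(grads)
print(torch.sqrt(sum(g.square().sum() for g in grads)))  # <= 1.0 (up to rounding)
```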
@@ -2284,17 +2678,6 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
      _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)


- @decorator_knowngood
- def _compilable_weight_decay_to_init_(p, init, weight_decay):
-     _lerp(p, promote(init), 1 - weight_decay)
-
-
- def weight_decay_to_init_(p, init, weight_decay):
-     p, init = list_guard(p, init)
-     weight_decay = scalar_guard(weight_decay, p[0])
-     _compilable_weight_decay_to_ema_(p, init, weight_decay)
-
-
  @decorator_knowngood
  def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
      ema32 = _lerp(ema, p, ema_decay)
@@ -2416,7 +2799,7 @@ def precond_grad_cached_(
      md = min_dtype(list(cached_q) + [ea])
      args = [q.to(md) for q in cached_q]
      args = args + [ea.to(md)]
-     expr = cached_precond_grad_expr(ndim_tuple(cached_q), grad.ndim)
+     expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
      new = compiled_einsum(expr, *args)
      if cast:
          return new.to(ea.dtype)
@@ -2433,7 +2816,7 @@ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, c


  def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
-     lr = scalar_guard(lr, param[0])
+     lr, decay = scalar_guard(lr, decay, param[0])
      _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)


@@ -2502,7 +2885,7 @@ def fused_psgd_precond_grad(
      store_triu_as_line: bool = False,
      symmetric_output: bool = False,
  ):
-     lr = scalar_guard(lr, param[0])
+     lr, decay = scalar_guard(lr, decay, param[0])
      _compilable_fused_psgd_precond_grad(
          ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
      )