heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import collections
2
2
  import contextlib
3
+ import enum
3
4
  import functools
4
5
  import gc
5
6
  import inspect
@@ -16,13 +17,29 @@ import torch
16
17
  from torch import Tensor
17
18
  from torch._dynamo.exc import TorchDynamoException
18
19
  from torch.backends import cudnn, opt_einsum
20
+ from torch.nn import functional as F
19
21
  from torch.utils._pytree import tree_map
20
22
 
23
+
24
class ZerothPowerMode(enum.Enum):
    """
    Algorithm used to orthogonalize a matrix (compute its "zeroth power").

    Each member's value is the plain string accepted wherever a mode may be passed as ``str``;
    `_compilable_orthogonal_` converts such strings with ``ZerothPowerMode(mode)``.
    """

    # Dispatches to `zeropower_via_newtonschulz5`; also used for every non-square input.
    newtonschulz = "newtonschulz"
    # Dispatches to `legacy_zeropower_via_newtonschulz5`.
    legacy_newtonschulz = "legacy_newtonschulz"
    # Q factor of `torch.linalg.qr` on the promoted input.
    qr = "qr"
    # `u @ vt` from `torch.linalg.svd` on the promoted input.
    svd = "svd"
    # `u @ vt.T` from `torch.linalg.svd`, i.e. the formula the removed string-based "svd" mode used.
    legacy_svd = "legacy_svd"
30
+
31
+
32
class OrthoScaleMode(enum.Enum):
    """
    Post-processing applied to an orthogonalized update in `_compilable_orthogonal_`.

    Each member's value is the plain string form of the mode.
    """

    # Leave the orthogonalized matrix untouched.
    none = "none"
    # Multiply by sqrt(max(1, rows / cols)), using the last two dimensions of the input.
    scale = "scale"
    # Pass input and orthogonalized output through `_compilable_grafting(x, y)`.
    graft = "graft"
36
+
37
+
21
38
  compile_mode = "max-autotune-no-cudagraphs"
22
39
  dynamic = False
23
40
  compile_mode_recommended_to_none = None
24
41
  zeroth_power_mode = "newtonschulz"
25
- precise_zeroth_power_mode = "qr" # or svd
42
+ precise_zeroth_power_mode = "qr"
26
43
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
44
  _cudnn_double_backward_pattern = re.compile(
28
45
  r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -240,14 +257,14 @@ def eps_sqrt(item, eps):
240
257
 
241
258
  @decorator_knowngood
242
259
  def _compilable_exp_avg_sq_(
243
- state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
260
+ state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
244
261
  ):
245
262
  g32 = promote(grad)
246
263
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
247
264
 
248
265
  denom = [eps_sqrt(d, eps) for d in s32]
249
266
 
250
- if out[0] is None:
267
+ if out is None or out[0] is None:
251
268
  return denom
252
269
 
253
270
  copy_stochastic_list_(out, denom)
@@ -316,8 +333,8 @@ def adaptive_gradient_clipping_(
316
333
  def is_compiling():
317
334
  try:
318
335
  return torch.compiler.is_compiling()
319
- except TorchDynamoException:
320
- return True
336
+ except (TorchDynamoException, AttributeError):
337
+ return False
321
338
 
322
339
 
323
340
  def set_(dst: Tensor, src: Tensor):
@@ -366,12 +383,42 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
366
383
  )
367
384
 
368
385
 
369
- @decorator
386
@decorator_knowngood
def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
    """
    Approximately orthogonalize ``G`` with five quintic Newton-Schulz iterations.

    The last two dimensions are treated as the matrix; any leading dimensions are batch
    dimensions (batched Muon implementation by @scottjmaddox, and put into practice in the
    record by @YouJiacheng).

    Args:
        G: tensor with at least two dimensions. It is not modified.
        steps: must be 5; the coefficient schedule below has exactly five entries.
        eps: added to the Frobenius norm so an all-zero input does not divide by zero.

    Returns:
        Tensor with the shape and dtype of ``G``.

    Raises:
        AssertionError: if ``G.ndim < 2`` or ``steps != 5``.
    """
    assert G.ndim >= 2
    assert steps == 5

    # Ensure top singular value <= 1: the Frobenius norm bounds the spectral norm from above,
    # so we DIVIDE by it (multiplying would square the scale and make the iteration diverge).
    # keepdim=True keeps one norm per matrix broadcastable against batched input.
    # The division is out-of-place, so the caller's tensor is never touched, and it happens
    # before rounding so the low-precision copy is made from already-normalized values.
    G = G / (G.norm(dim=(-2, -1), keepdim=True) + eps)
    # float64 input is iterated as-is; everything else goes through stochastic rounding
    # (presumably to bfloat16 - confirm against `stochastic_round_`).
    X = G if G.dtype == torch.float64 else stochastic_round_(G)

    # Work on the wide orientation so the Gram matrix X @ X.mT has the smaller side length.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT

    # Perform the NS iterations
    for a, b, c in [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]:
        A = X @ X.mT
        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if transposed:
        X = X.mT
    return X.to(G.dtype)
413
+
414
+
415
+ @decorator_knowngood
416
+ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
371
417
  assert len(G.shape) == 2
372
418
  a, b, c = (3.4445, -4.7750, 2.0315)
373
- X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
374
- X /= X.norm() + eps # ensure top singular value <= 1
419
+ G = G.clone()
420
+ X = G if G.dtype == torch.float64 else stochastic_round_(G)
421
+ stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps) # ensure top singular value <= 1
375
422
  if G.size(0) > G.size(1):
376
423
  X = X.T
377
424
  for _ in range(steps):
@@ -435,21 +482,30 @@ def _compilable_grafting(magnitude, direction):
435
482
 
436
483
 
437
484
  @decorator_knowngood
438
- def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
439
- if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
485
+ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
486
+ if not isinstance(mode, ZerothPowerMode):
487
+ mode = ZerothPowerMode(mode)
488
+ if not isinstance(scale_mode, ZerothPowerMode):
489
+ scale_mode = OrthoScaleMode(scale_mode)
490
+ if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
440
491
  y = zeropower_via_newtonschulz5(x, 5)
441
- elif mode == "qr":
492
+ elif mode == ZerothPowerMode.legacy_newtonschulz:
493
+ y = legacy_zeropower_via_newtonschulz5(x, 5)
494
+ elif mode == ZerothPowerMode.qr:
442
495
  y = torch.linalg.qr(promote(x)).Q
443
- elif mode == "svd":
444
- u, _s, v = torch.linalg.svd(promote(x))
445
- y = u @ v.T
496
+ elif mode == ZerothPowerMode.svd:
497
+ u, _s, vt = torch.linalg.svd(promote(x))
498
+ y = u @ vt
499
+ elif mode == ZerothPowerMode.legacy_svd:
500
+ u, _s, vt = torch.linalg.svd(promote(x))
501
+ y = u @ vt.T
446
502
  else:
447
503
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
448
- if scale_mode == "none":
504
+ if scale_mode == OrthoScaleMode.none:
449
505
  pass
450
- elif scale_mode == "scale":
451
- y *= max(1, x.size(0) / x.size(1)) ** 0.5
452
- elif scale_mode == "graft":
506
+ elif scale_mode == OrthoScaleMode.scale:
507
+ y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
508
+ elif scale_mode == OrthoScaleMode.graft:
453
509
  y = _compilable_grafting(x, y)
454
510
  else:
455
511
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
@@ -556,10 +612,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
556
612
  except torch.OutOfMemoryError:
557
613
  if m.device.type == "cpu":
558
614
  raise
559
- else:
615
+ if torch.cuda.is_available():
616
+ torch.cuda.synchronize(m.device)
617
+ clean()
618
+ m = m.cpu()
619
+ except RuntimeError as e:
620
+ if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
621
+ torch.cuda.synchronize(m.device)
622
+ clean()
560
623
  m = m.cpu()
561
- except RuntimeError: # failed to compute eigenvalues
562
- if m.dtype != torch.double:
624
+ elif m.dtype != torch.double:
563
625
  m = m.double()
564
626
  elif eps < max_eps:
565
627
  eps = eps ** (2 / 3)
@@ -658,7 +720,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
658
720
  copy_stochastic_(x_, x32 + y32 * alpha)
659
721
 
660
722
 
661
- def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
723
+ def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
662
724
  x, y = broadcastable_list_guard(x, y)
663
725
  alpha = scalar_guard(alpha, x[0])
664
726
  _compilable_stochastic_add_(x, y, alpha)
@@ -672,7 +734,9 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
672
734
  copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
673
735
 
674
736
 
675
- def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
737
+ def stochastic_add_divide_(
738
+ x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
739
+ ):
676
740
  x, y = broadcastable_list_guard(x, y)
677
741
  alpha, divisor = scalar_guard(alpha, divisor, x[0])
678
742
  _compilable_stochastic_add_divide_(x, y, alpha, divisor)
@@ -686,7 +750,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
686
750
  copy_stochastic_(x_, x32 * y32)
687
751
 
688
752
 
689
- def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
753
+ def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
690
754
  x, y = broadcastable_list_guard(x, y)
691
755
  _compilable_stochastic_multiply_(x, y)
692
756
 
@@ -832,6 +896,10 @@ class ExactHVPFailed(ValueError):
832
896
  use_default = object()
833
897
 
834
898
 
899
def _tensor_key(x: Tensor):
    """Return a hashable key for ``x``: (data pointer, element count, dtype, device)."""
    address = x.data_ptr()
    count = x.numel()
    return (address, count, x.dtype, x.device)
901
+
902
+
835
903
  class StatefulOptimizer(torch.optim.Optimizer):
836
904
  """
837
905
  finite_differences saves memory, but needs more compute. (Alternative is true HVP)
@@ -874,7 +942,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
874
942
 
875
943
  self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
876
944
  self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
877
- self._init_mapping()
878
945
 
879
946
  def _store_stats(self, state_dict: dict[str, any]):
880
947
  state_dict["heavyball"] = {
@@ -905,7 +972,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
905
972
  def state_(self, arg: Tensor, fail: bool = True):
906
973
  if not fail and arg not in self.mapping:
907
974
  return {}
908
- state_param, index = self.mapping_inverse[arg]
975
+ if _tensor_key(arg) not in self.mapping_inverse:
976
+ self._init_mapping()
977
+ state_param, index = self.mapping_inverse[_tensor_key(arg)]
909
978
  if state_param not in self.state:
910
979
  self.state[state_param] = collections.defaultdict(dict)
911
980
  return self.state[state_param][index]
@@ -928,7 +997,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
928
997
  if p not in self.mapping:
929
998
  self.mapping[p] = p_views = merge_group(group, p)
930
999
  for i, pv in enumerate(p_views):
931
- self.mapping_inverse[pv] = (p, i)
1000
+ self.mapping_inverse[_tensor_key(pv)] = (p, i)
932
1001
 
933
1002
  def split_p_and_g_in_group(
934
1003
  self,
@@ -949,12 +1018,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
949
1018
  yield p, grad
950
1019
  continue
951
1020
 
952
- if p in self.mapping:
953
- p_views = self.mapping[p]
954
- else:
955
- self.mapping[p] = p_views = merge_group(group, p)
956
- for i, pv in enumerate(p_views):
957
- self.mapping_inverse[pv] = (p, i)
1021
+ self.mapping[p] = p_views = merge_group(group, p)
1022
+ for i, pv in enumerate(p_views):
1023
+ self.mapping_inverse[_tensor_key(pv)] = (p, i)
958
1024
 
959
1025
  vector = getattr(p, "vector", None)
960
1026
  hessian_vector = getattr(p, "hessian_vector", None)
@@ -1199,17 +1265,53 @@ def _compilable_adam_(
1199
1265
 
1200
1266
 
1201
1267
  def adam_(
1268
+ exp_avg: List[Tensor] | Tensor,
1269
+ exp_avg_sq: List[Tensor] | Tensor,
1270
+ grad: List[Tensor] | Tensor,
1271
+ beta1: float,
1272
+ beta2: float,
1273
+ step: int,
1274
+ eps: float = 1e-8,
1275
+ ) -> List[Tensor]:
1276
+ exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
1277
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
1278
+ _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
1279
+ return grad
1280
+
1281
+
1282
+ @decorator_knowngood
1283
+ def _compilable_unscaled_adam_(
1202
1284
  exp_avg: List[Tensor],
1203
1285
  exp_avg_sq: List[Tensor],
1204
1286
  grad: List[Tensor],
1287
+ beta1: Tensor,
1288
+ beta2: Tensor,
1289
+ step: Tensor,
1290
+ eps: Tensor,
1291
+ ):
1292
+ beta1 = beta_debias(beta1, step)
1293
+ beta2 = beta_debias(beta2, step)
1294
+
1295
+ g32 = list(map(promote, grad))
1296
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
1297
+ g32 = torch._foreach_div(g32, denom)
1298
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
1299
+ u32 = torch._foreach_mul(exp_avg32, denom)
1300
+ copy_stochastic_list_(grad, u32)
1301
+
1302
+
1303
+ def unscaled_adam_(
1304
+ exp_avg: List[Tensor] | Tensor,
1305
+ exp_avg_sq: List[Tensor] | Tensor,
1306
+ grad: List[Tensor] | Tensor,
1205
1307
  beta1: float,
1206
1308
  beta2: float,
1207
1309
  step: int,
1208
1310
  eps: float = 1e-8,
1209
- ):
1311
+ ) -> List[Tensor]:
1210
1312
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
1211
1313
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
1212
- _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
1314
+ _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
1213
1315
  return grad
1214
1316
 
1215
1317
 
@@ -1253,7 +1355,7 @@ def fused_adam_(
1253
1355
  caution: bool,
1254
1356
  ):
1255
1357
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
1256
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
1358
+ beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
1257
1359
  _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
1258
1360
 
1259
1361
 
@@ -1332,7 +1434,7 @@ def fused_laprop_(
1332
1434
  eps: float = 1e-8,
1333
1435
  ):
1334
1436
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
1335
- beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
1437
+ beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
1336
1438
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
1337
1439
 
1338
1440
 
@@ -1351,7 +1453,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
1351
1453
 
1352
1454
  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
1353
1455
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
1354
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
1456
+ beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
1355
1457
  _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
1356
1458
 
1357
1459
 
@@ -1381,11 +1483,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
1381
1483
 
1382
1484
 
1383
1485
  @decorator_knowngood
1384
- def stochastic_round_(ref: Tensor, source: Tensor):
1385
- if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
1386
- return source
1387
- if ref.dtype != torch.bfloat16:
1388
- return source.to(ref.dtype)
1486
+ def stochastic_round_(ref: Tensor, source: Tensor | None = None):
1487
+ if source is not None:
1488
+ if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
1489
+ return source
1490
+ if ref.dtype != torch.bfloat16:
1491
+ return source.to(ref.dtype)
1492
+ else:
1493
+ source = ref
1494
+ source = source.float()
1389
1495
  result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
1390
1496
  result.add_(source.view(dtype=torch.int32))
1391
1497
  result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
@@ -1908,7 +2014,9 @@ def ndim_tuple(Q: list[Tensor]) -> tuple:
1908
2014
  return tuple(q.ndim for q in Q)
1909
2015
 
1910
2016
 
1911
- def psgd_calc_A_and_conjB(G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
2017
+ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "vector") == randn during hvp/whitening
2018
+ if conjB is None:
2019
+ conjB = torch.randn_like(G)
1912
2020
  exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
1913
2021
  A = casted_einsum(exprA, *Q, G)
1914
2022
  solve = torch.compiler.disable(torch.linalg.solve_triangular)
@@ -1940,24 +2048,35 @@ def max_singular_value_exact(A, use_lobpcg: bool = False):
1940
2048
  eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
1941
2049
  return eigval[0].sqrt()
1942
2050
  else:
1943
- return torch.linalg.svd(A, driver="gesvdj")[1].max() # == linalg.matrix_norm(A, ord=2)
1944
- except torch.linalg.LinAlgError:
1945
- return torch.zeros((), device=A.device, dtype=A.dtype)
2051
+ return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype) # == linalg.matrix_norm(A, ord=2)
2052
+ except (torch.linalg.LinAlgError, RuntimeError):
2053
+ return max_singular_value_power_iter(promote(A), iterations=2)
1946
2054
 
1947
2055
 
1948
2056
  @decorator_knowngood
1949
- def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
2057
+ def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
1950
2058
  """
1951
2059
  Rayleigh quotient of row with the largest norm + optional power iterations
1952
2060
  """
1953
- x_norm, max_idx = A.norm(dim=1).max(dim=0)
1954
- x = A.index_select(0, max_idx).flatten().contiguous()
1955
- A = A / x_norm
1956
- x = x / x_norm
1957
- for _ in range(iterations):
1958
- x = A.T.mv(A.mv(x)) # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
1959
- x = x / x.norm()
1960
- return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm
2061
+ x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
2062
+ x_norm = promote(x_norm)
2063
+
2064
+ def _inner():
2065
+ A = A_outer
2066
+ x = A.index_select(0, max_idx).flatten().contiguous()
2067
+ A = stochastic_round_(A / x_norm)
2068
+ x = x / x_norm
2069
+
2070
+ def _mv(x):
2071
+ return promote(A.T.mv(A.mv(stochastic_round_(x))))
2072
+
2073
+ for _ in range(iterations):
2074
+ # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
2075
+ x = F.normalize(_mv(x), dim=0)
2076
+ out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
2077
+ return out.squeeze().clone()
2078
+
2079
+ return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
1961
2080
 
1962
2081
 
1963
2082
  @decorator_knowngood
@@ -1974,33 +2093,81 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
1974
2093
  return sketch_norm * max_abs
1975
2094
 
1976
2095
 
2096
def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
    """
    Largest singular value estimate that also accepts tensors with more than two dimensions.

    Inputs with ``ndim <= 2`` are forwarded unchanged to `max_singular_value`. For higher-rank
    input, one square matrix is built per dimension by contracting ``A`` with itself over all
    other dimensions; the largest estimate across those matrices is square-rooted and returned.

    Args:
        A: input tensor.
        max_svd, use_cholesky, power_iter: forwarded to `max_singular_value`.

    Returns:
        Scalar tensor cast to ``A.dtype`` (for the ``ndim > 2`` path).
    """
    if A.ndim <= 2:
        return max_singular_value(A, max_svd, use_cholesky, power_iter)

    # One einsum index letter per dimension of A.
    base = einsum_base[: A.ndim]
    A16 = stochastic_round_(A)
    # For each dimension b: upper-case b in the second operand so it is the only index NOT
    # contracted, producing a (size_b, size_b) matrix indexed by "b" and "B".
    squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
    svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
    svds = torch.stack(svds)
    return svds.max().sqrt().to(A.dtype)  # sqrt because we took the SVD of a squared matrix
2106
+
2107
+
1977
2108
  @decorator_knowngood
1978
- def max_singular_value(
1979
- A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
1980
- ) -> Tensor:
2109
+ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
2110
+ if A.ndim < 2:
2111
+ return A.abs().max()
2112
+ if A.ndim > 2:
2113
+ raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
1981
2114
  if min(A.shape) <= max_svd:
1982
2115
  return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
1983
2116
  if use_cholesky or power_iter < 0:
1984
- return max_singular_value_cholesky(A, max_abs)
2117
+ return max_singular_value_cholesky(A)
1985
2118
  return max_singular_value_power_iter(A, None, iterations=power_iter)
1986
2119
 
1987
2120
 
1988
2121
  @decorator_knowngood
1989
- def _psgd_default_preconditioner_grad(
1990
- terms: List[Tuple[Tensor, Tensor]],
1991
- Q: List[Tensor],
1992
- ) -> List[Tensor]:
1993
- out = []
1994
- for q, (x, y) in zip(Q, terms):
1995
- x = promote(x)
1996
- y = promote(y)
1997
- update = x - y
1998
- if q.ndim < 2:
1999
- update = q * update
2000
- else:
2001
- update = (q @ update).triu()
2002
- out.append(update)
2003
- return out
2122
+ def clamped_max_singular_value(
2123
+ A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
2124
+ ) -> Tensor:
2125
+ norm = A.norm() # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
2126
+ out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
2127
+ return out.clamp(min=min)
2128
+
2129
+
2130
+ @decorator_knowngood
2131
def min_singular_value(
    A: Tensor,
    power_iter: int = 5,
    safety: float = 1.05,
    max_svd: int = 32,
):
    """
    Estimate the smallest eigenvalue of ``A`` with a shifted power iteration.

    NOTE(review): the code uses `eigvalsh` and a spectral shift, so it appears to assume ``A``
    is symmetric (where eigenvalues and singular values coincide for PSD input) - confirm
    against callers.

    Args:
        A: vector or square matrix.
        power_iter: number of shifted power-iteration steps; also forwarded to
            `max_singular_value` for the upper bound.
        safety: multiplier applied to the largest-eigenvalue estimate to obtain the shift.
        max_svd: matrices with at most this many rows are solved exactly with `eigvalsh`.

    Returns:
        Scalar tensor. For ``A.ndim < 2`` this is simply the smallest absolute entry.
    """
    if A.ndim < 2:
        return A.abs().min()

    n = A.size(0)
    if n <= max_svd:
        try:
            eigs = torch.linalg.eigvalsh(promote(A))
            return eigs.min().to(A.dtype)
        except torch.linalg.LinAlgError:
            # Exact decomposition failed; fall through to the iterative estimate below.
            pass

    # Shift: an inflated estimate of the largest eigenvalue, so (lambda_upper * I - A) has its
    # dominant eigenvalue where A has its smallest one.
    lambda_max_hat = max_singular_value(A, power_iter=power_iter)
    lambda_upper = lambda_max_hat * safety

    # Start from the row with the smallest norm, or a random vector if that row is all zeros.
    row_norms = A.norm(dim=1)
    norm, idx = row_norms.min(dim=0)
    v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))

    v = v / promote(v.norm())
    for _ in range(power_iter):
        # One power-iteration step on the shifted operator, followed by renormalization.
        v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
        v = v / promote(v.norm())
    # Rayleigh quotient of the shifted operator at the final iterate.
    mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))

    # Undo the shift.
    lambda_min_hat = lambda_upper - mu_hat

    def _approx():
        # Closed-form fallback built only from trace(A) and sum(A**2):
        # mean minus sqrt(variance / (n - 1)).
        mu = A.trace() / n
        sigma_square = A.square().sum() / n - mu**2
        return mu - (sigma_square / (n - 1)).sqrt()

    # Use the fallback whenever the iterative estimate is non-finite or not strictly positive.
    return cond(
        (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
    ).squeeze()
2004
2171
 
2005
2172
 
2006
2173
  @decorator_knowngood
@@ -2027,6 +2194,107 @@ def calcG_expr(q_dim, g_dim):
2027
2194
  return exprs
2028
2195
 
2029
2196
 
2197
def eye_like(x: Tensor):
    """
    Return the multiplicative identity matching ``x``.

    Scalars and vectors yield a tensor of ones with the shape of ``x``; a square matrix yields
    an identity matrix on the same device and with the same dtype.

    Raises:
        AssertionError: if ``x`` has more than two dimensions or is a non-square matrix.
    """
    if x.ndim >= 2:
        assert x.ndim == 2
        side = x.size(0)
        assert side == x.size(1)
        return torch.eye(side, device=x.device, dtype=x.dtype)
    return torch.ones_like(x)
2203
+
2204
+
2205
+ @decorator_knowngood
2206
def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
    """
    Idea:
        G should be zeroth power. So, all Qs together should approximate the G's inverse.
        Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R
        Optimize LGR being a zeroth power using `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T (LGR) , I )`,
        then backprop to L/R jointly.
        This function computes the gradients for L/R, with an outer optimizer layer handling the rest.

    `psgd_precond_grad` computes LGR for the general (n-dimensional) case
    `exprG` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case

    The gradients are written out by hand rather than obtained through autograd.

    Args:
        G: Gradient that should be orthogonalized
        Q: List of preconditioner tensors.

    Returns:
        A tuple ``(grads, P)``:
        - grads: list of gradients with respect to each entry of Q (d_Q); 2-D entries are symmetrized.
        - P: the preconditioned gradient, cast back to ``G.dtype``.
    """
    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)

    G16 = stochastic_round_(G)
    Q16 = [stochastic_round_(q) for q in Q]
    P = psgd_precond_grad(G16, Q16)  # Q₀GQ₁

    # Gradient of the loss with respect to P, summed over one term per dimension.
    d_P = torch.zeros_like(G)
    base = einsum_base[: G.ndim]
    for i, exprG in enumerate(exprGs):
        pp = compiled_einsum(exprG, P, P)
        # Deviation of this dimension's self-contraction from the identity (ones for vectors).
        error = pp - eye_like(pp)
        dim = einsum_base[i]
        if pp.ndim == 2:
            # Matrix error: contract P's index `dim` with it, leaving the new index in its place.
            new = dim.upper()
            prec = f"{new}{dim}"
        else:
            # Vector error: elementwise scaling along `dim`.
            new = dim
            prec = dim
        d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)

    d_P = stochastic_round_(d_P)  # accumulate in fp32 and round at the end
    grads = []
    for i, exprG in enumerate(exprGs):
        # Replace the i-th preconditioner with the identity so `pq` is P without Q[i] applied.
        new_q = Q16[:]
        new_q[i] = eye_like(new_q[i])
        pq = psgd_precond_grad(G16, new_q)
        grad = compiled_einsum(exprG, pq, d_P)
        if grad.ndim == 2:
            # Keep matrix gradients symmetric.
            grad = (grad + grad.T) / 2
        grads.append(grad)

    return grads, P.to(G.dtype)
2256
+
2257
+
2258
def _inverse_initial_guess(gg):
    """
    Return ``(sigma_max, sigma_min_approx)`` for the square matrix ``gg``.

    ``sigma_max`` is the promoted norm of ``gg`` (``gg.norm()``); ``sigma_min_approx`` is
    ``trace(gg) / (n * sigma_max)`` with ``n = gg.shape[0]``.
    """
    size = gg.shape[0]
    upper = promote(gg.norm())
    trace = promote(torch.trace(gg))
    return upper, trace / (size * upper)
2267
+
2268
+
2269
+ @decorator_knowngood
2270
def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
    """
    Chebyshev coefficients of ``f(x) = (x + 1 + eps) ** -0.5`` from ``degree`` cosine nodes.

    All arithmetic is done in float64 and the results are cast to float32.

    Returns:
        ``(coeff0, coeffs)``: the order-0 coefficient (mean of the samples) as a scalar tensor,
        and a tensor with the ``degree - 1`` coefficients of orders ``1 .. degree - 1``.
    """
    orders = torch.arange(degree, dtype=torch.float64, device=device)
    # Node angles (2k + 1) * pi / (2 * degree); the nodes themselves are their cosines.
    angles = (2 * orders + 1) * math.pi / (2 * degree)
    samples = (angles.cos() + 1 + eps) ** -0.5
    # cos(j * angle_k) for every node k (rows) and every order j >= 1 (columns).
    basis = (angles.view(-1, 1) * orders[1:].view(1, -1)).cos()
    leading = samples.sum() / degree
    higher = samples @ basis * 2 / degree
    return leading.float(), higher.float()
2278
+
2279
+
2280
+ @decorator_knowngood
2281
def _psgd_default_preconditioner_grad(
    terms: List[Tuple[Tensor, Tensor]],
    Q: List[Tensor],
) -> List[Tensor]:
    """
    Build one preconditioner update per ``(q, (x, y))`` pair.

    The promoted difference ``x - y`` is scaled elementwise by ``q`` when ``q`` has fewer than
    two dimensions; otherwise it is left-multiplied by ``q`` and reduced to its upper triangle.
    """

    def _single(q: Tensor, x: Tensor, y: Tensor) -> Tensor:
        delta = promote(x) - promote(y)
        if q.ndim >= 2:
            return (q @ delta).triu()
        return q * delta

    return [_single(q, x, y) for q, (x, y) in zip(Q, terms)]
2296
+
2297
+
2030
2298
  @decorator
2031
2299
  def psgd_update_precond(
2032
2300
  G: Tensor,
@@ -2056,6 +2324,102 @@ def psgd_update_precond(
2056
2324
  return None
2057
2325
 
2058
2326
 
2327
+ @decorator_knowngood
2328
def bf16_matmul(x: Tensor, y: Tensor):
    """Matrix-multiply the promoted operands and cast the product back to ``x.dtype``."""
    lhs = promote(x)
    rhs = promote(y)
    product = torch.matmul(lhs, rhs)
    return product.to(x.dtype)
2330
+
2331
+
2332
def if_iscompiling(fn):
    """
    Decorator that swaps ``fn`` for the same-named ``torch`` function while compiling.

    If ``torch`` exposes an attribute called ``fn.__name__`` and ``torch.compiler.is_compiling()``
    is true, the call is forwarded to that attribute; in every other case ``fn`` itself runs.

    All positional and keyword arguments are forwarded unchanged, so the decorator works for
    functions of any arity (``cond`` and ``while_loop`` take three arguments).

    Args:
        fn: eager-mode fallback implementation.

    Returns:
        A wrapper with the same call signature as ``fn``.
    """
    # Resolved once at decoration time; None when torch has no function of that name.
    base = getattr(torch, fn.__name__, None)

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        if base is not None and torch.compiler.is_compiling():
            return base(*args, **kwargs)
        return fn(*args, **kwargs)

    return _fn
2341
+
2342
+
2343
+ @if_iscompiling
2344
def while_loop(cond, body, state):
    """
    Naive eager loop: keep replacing ``state`` with ``body(*state)`` while ``cond(*state)`` holds.

    This is the slow fallback that runs outside of compilation (the `if_iscompiling` decorator
    dispatches to the torch equivalent when compiling); useful for debugging.

    Args:
        cond: callable returning a single-element tensor; its ``.item()`` decides continuation.
        body: callable returning the next state.
        state: iterable of positional arguments for ``cond`` and ``body``.

    Returns:
        The first state for which ``cond`` is false.
    """
    while True:
        keep_going = cond(*state).item()
        if not keep_going:
            return state
        state = body(*state)
2352
+
2353
+
2354
+ @if_iscompiling
2355
def cond(cond, true_fn, false_fn):
    """
    Naive eager branch: call ``true_fn`` if ``cond.item()`` is truthy, otherwise ``false_fn``.

    This is the slow fallback that runs outside of compilation (the `if_iscompiling` decorator
    dispatches to the torch equivalent when compiling); useful for debugging.
    """
    chosen = true_fn if cond.item() else false_fn
    return chosen()
2364
+
2365
+
2366
+ @decorator_knowngood
2367
def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
    """
    Householder vector w for the reflector H = I - 2 w w^T that maps e1 onto the unit vector v.

    Applied from the right, G @ H = G - 2 (G @ w) w^T.
    ``v`` is normalized first; when it (numerically) equals e1 the result is w = 0, i.e. H = I.
    Both normalizations clamp the norm at ``eps`` to avoid dividing by zero.
    """
    target = v / v.norm().clamp(min=eps)
    first_axis = torch.zeros_like(target)
    first_axis[0] = 1.0
    difference = first_axis - target
    length = difference.norm().clamp(min=eps)
    return difference / length
2378
+
2379
+
2380
+ @decorator_knowngood
2381
def eigvecs_product_rank1(
    G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
) -> Tuple[Tensor, Tensor]:
    """
    Rotate G into the eigenbasis of P = λ I + σ v v^T without ever materializing that basis.

    The eigenvector matrix is represented by the Householder reflector whose first column is
    v, so the product reduces to Y = G - 2 (G @ w) w^T.

    Args:
        G: shape (..., d) — gradient row(s) to rotate.
        v: shape (d,) — current unit direction (top eigenvector of P).
        w: optional precomputed Householder vector; computed from ``v`` when omitted.
        eps: forwarded to the Householder-vector computation when ``w`` is omitted.

    Returns:
        (Y, w): the rotated rows with shape (..., d), and the Householder vector, which can be
        cached and passed back in on later calls.
    """
    reflector = _householder_vec_e1_to_v(v, eps) if w is None else w
    projection = compiled_einsum("...i,i,j->...j", G, reflector, reflector)
    rotated = G - 2.0 * projection
    return rotated, reflector
2402
+
2403
+
2404
+ @decorator_knowngood
2405
def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
    """
    Take one Oja step towards the top eigendirection of the gradient covariance.

    Computes v + lr * ((g^T v) g - (g^T v)^2 v) and renormalizes the result, clamping the norm
    at ``eps`` so a zero vector stays zero instead of producing NaN.
    """
    projection = g @ v
    correction = projection * g - (projection * projection) * v
    stepped = v + lr * correction
    return stepped / stepped.norm().clamp(min=eps)
2413
+
2414
+
2415
def cond_n(cond_val: Tensor, *fns):
    """
    Run the branch selected by an integer index, built from nested two-way `cond` calls.

    ``fns[0]`` runs when ``cond_val == 0``, ``fns[1]`` when ``cond_val == 1``, and so on; the
    last function is the fallback for every remaining value.

    Args:
        cond_val: index of the branch to run.
        *fns: zero-argument branch functions; at least one is required.

    Returns:
        The return value of the selected branch.

    Raises:
        IndexError: if no branch function is given.
    """
    fn, remaining = fns[0], fns[1:]
    if not remaining:
        # Last branch: call it so this path yields a branch result like every other path,
        # rather than handing back the function object itself.
        return fn()
    return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *remaining))
2421
+
2422
+
2059
2423
  @decorator_knowngood
2060
2424
  def _psgd_precond_update_(
2061
2425
  matmuled: List[Optional[Tensor]],
@@ -2074,7 +2438,7 @@ def _psgd_precond_update_(
2074
2438
  if update.ndim < 2:
2075
2439
  lb = update.norm(float("inf"))
2076
2440
  else:
2077
- lb = max_singular_value(update, None, power_iter=power_iter)
2441
+ lb = max_singular_value(update, power_iter=power_iter)
2078
2442
  update = promote(update)
2079
2443
  if store_triu_as_line:
2080
2444
  update = triu_to_line([update])[0][1]
@@ -2146,70 +2510,83 @@ def inverse_free_psgd_update_precond(
2146
2510
 
2147
2511
 
2148
2512
  @decorator_knowngood
2149
- def _compilable_l2_clip_(x, clip_at):
2150
- ref = x
2151
- x = list(map(promote, x))
2152
- norm = torch._foreach_norm(x)
2153
- torch._foreach_maximum_(norm, clip_at)
2154
- out = torch._foreach_div(x, norm)
2155
- return stochastic_round_list_(ref, out)
2513
+ def _clip(x, norm, clip_at, eps=1e-8):
2514
+ x32 = promote(x)
2515
+ # (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
2516
+ norm = clip_at / norm.clamp(min=max(clip_at, eps))
2517
+ x32 = x32 * norm
2518
+ copy_stochastic_(x, x32)
2519
+
2520
+
2521
+ @decorator_knowngood
2522
+ def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
2523
+ for x in xs:
2524
+ _clip(x, promote(x).norm(), clip_at, eps)
2156
2525
 
2157
2526
 
2158
2527
  def l2_normalization_(x, clip_at: float = 1e-8):
2159
2528
  x = list_guard(x)
2160
- return _compilable_l2_clip_(x, clip_at)
2529
+ _compilable_l2_clip_(x, clip_at)
2530
+ return x
2161
2531
 
2162
2532
 
2163
2533
  def l2_clip_(x, clip_at: float = 1.0):
2164
2534
  x = list_guard(x)
2165
- return _compilable_l2_clip_(x, clip_at)
2535
+ _compilable_l2_clip_(x, clip_at)
2536
+ return x
2166
2537
 
2167
2538
 
2168
2539
  @decorator_knowngood
2169
- def _compilable_rmsnorm_clip_(x, clip_at):
2170
- x = list(map(promote, x))
2171
- norm = torch._foreach_norm(x)
2172
- norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
2173
- torch._foreach_maximum_(norm, clip_at)
2174
- return torch._foreach_div(x, norm)
2540
+ def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
2541
+ for x in xs:
2542
+ _clip(x, promote(x).square().mean().sqrt(), clip_at, eps)
2175
2543
 
2176
2544
 
2177
2545
  def rmsnorm_clip_(x, clip_at: float = 1.0):
2178
2546
  x = list_guard(x)
2179
- return _compilable_rmsnorm_clip_(x, clip_at)
2180
-
2181
-
2182
- @decorator_knowngood
2183
- def _compilable_global_rmsnorm_clip_(x, clip_at):
2184
- x = list(map(promote, x))
2185
- norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
2186
- norm = norm**0.5
2187
- norm = norm.clamp(min=clip_at)
2188
- return torch._foreach_div(x, norm)
2547
+ _compilable_rmsnorm_clip_(x, clip_at)
2548
+ return x
2189
2549
 
2190
2550
 
2191
2551
  @decorator_knowngood
2192
- def _compilable_global_l2norm_clip_(x, clip_at):
2193
- x = list(map(promote, x))
2194
- norm = sum([x.square().sum() for x in x])
2195
- norm = norm**0.5
2196
- norm = norm.clamp(min=clip_at)
2197
- return torch._foreach_div(x, norm)
2552
+ def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
2553
+ norm = 0
2554
+ numel = sum([i.numel() for i in x])
2555
+ for i in x:
2556
+ norm += promote(i).square().sum()
2557
+ norm = (norm / numel) ** 0.5
2558
+ scalar = clip_at / norm.clamp(min=max(clip_at, eps))
2559
+ stochastic_multiply_(x, scalar)
2198
2560
 
2199
2561
 
2200
2562
  def global_rmsnorm_clip(x, clip_at: float = 1.0):
2201
2563
  x = list_guard(x)
2202
- return _compilable_global_rmsnorm_clip_(x, clip_at)
2564
+ clip_at = scalar_guard(clip_at, x[0])
2565
+ _compilable_global_rmsnorm_clip_(x, clip_at)
2566
+ return x
2567
+
2568
+
2569
+ @decorator_knowngood
2570
+ def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
2571
+ norm = 0
2572
+ for i in x:
2573
+ norm += promote(i).square().sum()
2574
+ norm = norm**0.5
2575
+ scalar = clip_at / norm.clamp(min=max(clip_at, eps))
2576
+ stochastic_multiply_(x, scalar)
2203
2577
 
2204
2578
 
2205
2579
  def global_l2norm_clip(x, clip_at: float = 1.0):
2206
2580
  x = list_guard(x)
2207
- return _compilable_global_rmsnorm_clip_(x, clip_at)
2581
+ clip_at = scalar_guard(clip_at, x[0])
2582
+ _compilable_global_l2norm_clip_(x, clip_at)
2583
+ return x
2208
2584
 
2209
2585
 
2210
2586
  def rmsnorm_normalize_(x, clip_at: float = 1e-6):
2211
2587
  x = list_guard(x)
2212
- return _compilable_rmsnorm_clip_(x, clip_at)
2588
+ _compilable_rmsnorm_clip_(x, clip_at)
2589
+ return x
2213
2590
 
2214
2591
 
2215
2592
  @decorator_knowngood
@@ -2284,17 +2661,6 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
2284
2661
  _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
2285
2662
 
2286
2663
 
2287
- @decorator_knowngood
2288
- def _compilable_weight_decay_to_init_(p, init, weight_decay):
2289
- _lerp(p, promote(init), 1 - weight_decay)
2290
-
2291
-
2292
- def weight_decay_to_init_(p, init, weight_decay):
2293
- p, init = list_guard(p, init)
2294
- weight_decay = scalar_guard(weight_decay, p[0])
2295
- _compilable_weight_decay_to_ema_(p, init, weight_decay)
2296
-
2297
-
2298
2664
  @decorator_knowngood
2299
2665
  def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
2300
2666
  ema32 = _lerp(ema, p, ema_decay)
@@ -2416,7 +2782,7 @@ def precond_grad_cached_(
2416
2782
  md = min_dtype(list(cached_q) + [ea])
2417
2783
  args = [q.to(md) for q in cached_q]
2418
2784
  args = args + [ea.to(md)]
2419
- expr = cached_precond_grad_expr(ndim_tuple(cached_q), grad.ndim)
2785
+ expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
2420
2786
  new = compiled_einsum(expr, *args)
2421
2787
  if cast:
2422
2788
  return new.to(ea.dtype)
@@ -2433,7 +2799,7 @@ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, c
2433
2799
 
2434
2800
 
2435
2801
  def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
2436
- lr = scalar_guard(lr, param[0])
2802
+ lr, decay = scalar_guard(lr, decay, param[0])
2437
2803
  _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
2438
2804
 
2439
2805
 
@@ -2502,7 +2868,7 @@ def fused_psgd_precond_grad(
2502
2868
  store_triu_as_line: bool = False,
2503
2869
  symmetric_output: bool = False,
2504
2870
  ):
2505
- lr = scalar_guard(lr, param[0])
2871
+ lr, decay = scalar_guard(lr, decay, param[0])
2506
2872
  _compilable_fused_psgd_precond_grad(
2507
2873
  ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
2508
2874
  )