heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two released versions as they appear in the public registry.
- heavyball/__init__.py +168 -29
- heavyball/chainable.py +165 -63
- heavyball/helpers.py +5 -1
- heavyball/utils.py +490 -124
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/METADATA +19 -7
- heavyball-2.1.0.dist-info/RECORD +9 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/WHEEL +1 -1
- heavyball-2.0.0.dev0.dist-info/RECORD +0 -9
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import collections
 import contextlib
+import enum
 import functools
 import gc
 import inspect
@@ -16,13 +17,29 @@ import torch
 from torch import Tensor
 from torch._dynamo.exc import TorchDynamoException
 from torch.backends import cudnn, opt_einsum
+from torch.nn import functional as F
 from torch.utils._pytree import tree_map
 
+
+class ZerothPowerMode(enum.Enum):
+    newtonschulz = "newtonschulz"
+    legacy_newtonschulz = "legacy_newtonschulz"
+    qr = "qr"
+    svd = "svd"
+    legacy_svd = "legacy_svd"
+
+
+class OrthoScaleMode(enum.Enum):
+    none = "none"
+    scale = "scale"
+    graft = "graft"
+
+
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = "newtonschulz"
-precise_zeroth_power_mode = "qr"
+precise_zeroth_power_mode = "qr"
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 _cudnn_double_backward_pattern = re.compile(
     r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
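The two enums above replace bare-string mode checks in the orthogonalization helpers. Because each member's value is the matching string, string arguments can still be accepted and coerced through the enum constructor. A minimal sketch of that coercion pattern, using a trimmed copy of the enum rather than heavyball's own definitions:

```python
import enum


class ZerothPowerMode(enum.Enum):
    newtonschulz = "newtonschulz"
    qr = "qr"
    svd = "svd"


def coerce(mode: "str | ZerothPowerMode") -> ZerothPowerMode:
    # ZerothPowerMode("qr") -> ZerothPowerMode.qr; an unknown string raises ValueError.
    return mode if isinstance(mode, ZerothPowerMode) else ZerothPowerMode(mode)


assert coerce("qr") is ZerothPowerMode.qr
assert coerce(ZerothPowerMode.svd) is ZerothPowerMode.svd
```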
@@ -240,14 +257,14 @@ def eps_sqrt(item, eps):
 
 @decorator_knowngood
 def _compilable_exp_avg_sq_(
-    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[
+    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
 ):
     g32 = promote(grad)
     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
 
     denom = [eps_sqrt(d, eps) for d in s32]
 
-    if out[0] is None:
+    if out is None or out[0] is None:
         return denom
 
     copy_stochastic_list_(out, denom)
@@ -316,8 +333,8 @@ def adaptive_gradient_clipping_(
 def is_compiling():
     try:
         return torch.compiler.is_compiling()
-    except TorchDynamoException:
-        return
+    except (TorchDynamoException, AttributeError):
+        return False
 
 
 def set_(dst: Tensor, src: Tensor):
@@ -366,12 +383,42 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
     )
 
 
-@
+@decorator_knowngood
 def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
+    assert (
+        G.ndim >= 2
+    )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    assert steps == 5
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
+    # Perform the NS iterations
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
+        A = X @ X.mT
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X.to(G.dtype)
+
+
+@decorator_knowngood
+def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-
-    X
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
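The rewritten `zeropower_via_newtonschulz5` follows the batched Muon recipe: bring the top singular value near one, transpose tall matrices, then run five quintic iterations X ← aX + (bA + cA²)X with A = XXᵀ. The sketch below reproduces that iteration in plain float32 PyTorch, without heavyball's `stochastic_round_`/`stochastic_multiply_` helpers, and normalizes by dividing by the Frobenius norm; it is an illustration of the scheme, not the package's kernel:

```python
import torch


def newton_schulz_5(g: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration: pushes all singular values of g toward 1.
    x = g / (g.norm(dim=(-2, -1), keepdim=True) + eps)  # top singular value <= 1
    transposed = g.size(-2) > g.size(-1)
    if transposed:
        x = x.mT
    for a, b, c in [(4.0848, -6.8946, 2.9270), (3.9505, -6.3029, 2.6377),
                    (3.7418, -5.5913, 2.3037), (2.8769, -3.1427, 1.2046),
                    (2.8366, -3.0525, 1.2012)]:
        s = x @ x.mT
        x = a * x + (b * s + c * s @ s) @ x
    return x.mT if transposed else x


u = newton_schulz_5(torch.randn(64, 32))
print((u.mT @ u - torch.eye(32)).norm())  # small: u is approximately orthogonal
```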
@@ -435,21 +482,30 @@ def _compilable_grafting(magnitude, direction):
 
 
 @decorator_knowngood
-def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
-    if mode
+def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
+    if not isinstance(mode, ZerothPowerMode):
+        mode = ZerothPowerMode(mode)
+    if not isinstance(scale_mode, ZerothPowerMode):
+        scale_mode = OrthoScaleMode(scale_mode)
+    if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode ==
+    elif mode == ZerothPowerMode.legacy_newtonschulz:
+        y = legacy_zeropower_via_newtonschulz5(x, 5)
+    elif mode == ZerothPowerMode.qr:
         y = torch.linalg.qr(promote(x)).Q
-    elif mode ==
-        u, _s,
-        y = u @
+    elif mode == ZerothPowerMode.svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt
+    elif mode == ZerothPowerMode.legacy_svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
-    if scale_mode ==
+    if scale_mode == OrthoScaleMode.none:
         pass
-    elif scale_mode ==
-        y *= max(1, x.size(
-    elif scale_mode ==
+    elif scale_mode == OrthoScaleMode.scale:
+        y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
+    elif scale_mode == OrthoScaleMode.graft:
        y = _compilable_grafting(x, y)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
@@ -556,10 +612,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
         except torch.OutOfMemoryError:
             if m.device.type == "cpu":
                 raise
-
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(m.device)
+                clean()
+            m = m.cpu()
+        except RuntimeError as e:
+            if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
+                torch.cuda.synchronize(m.device)
+                clean()
                 m = m.cpu()
-
-        if m.dtype != torch.double:
+            elif m.dtype != torch.double:
                 m = m.double()
             elif eps < max_eps:
                 eps = eps ** (2 / 3)
@@ -658,7 +720,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
         copy_stochastic_(x_, x32 + y32 * alpha)
 
 
-def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
+def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
     x, y = broadcastable_list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
@@ -672,7 +734,9 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
         copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
 
 
-def stochastic_add_divide_(
+def stochastic_add_divide_(
+    x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
+):
     x, y = broadcastable_list_guard(x, y)
     alpha, divisor = scalar_guard(alpha, divisor, x[0])
     _compilable_stochastic_add_divide_(x, y, alpha, divisor)
@@ -686,7 +750,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
         copy_stochastic_(x_, x32 * y32)
 
 
-def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
     x, y = broadcastable_list_guard(x, y)
     _compilable_stochastic_multiply_(x, y)
 
@@ -832,6 +896,10 @@ class ExactHVPFailed(ValueError):
 use_default = object()
 
 
+def _tensor_key(x: Tensor):
+    return x.data_ptr(), x.numel(), x.dtype, x.device
+
+
 class StatefulOptimizer(torch.optim.Optimizer):
     """
     finite_differences saves memory, but needs more compute. (Alternative is true HVP)
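`_tensor_key` keys the view-to-parameter mapping by storage pointer, element count, dtype, and device rather than by the tensor object itself, so a view that is recreated later (for example after state is reloaded) still resolves to the same slot as long as it aliases the same storage. A minimal sketch of that behaviour:

```python
import torch


def tensor_key(x: torch.Tensor):
    return x.data_ptr(), x.numel(), x.dtype, x.device


p = torch.zeros(4)
view_a = p.view(4)
view_b = p.view(4)  # a distinct tensor object over the same storage
assert view_a is not view_b
assert tensor_key(view_a) == tensor_key(view_b)  # but both map to the same key
```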
@@ -874,7 +942,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
         self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
-        self._init_mapping()
 
     def _store_stats(self, state_dict: dict[str, any]):
         state_dict["heavyball"] = {
@@ -905,7 +972,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def state_(self, arg: Tensor, fail: bool = True):
         if not fail and arg not in self.mapping:
             return {}
-
+        if _tensor_key(arg) not in self.mapping_inverse:
+            self._init_mapping()
+        state_param, index = self.mapping_inverse[_tensor_key(arg)]
         if state_param not in self.state:
             self.state[state_param] = collections.defaultdict(dict)
         return self.state[state_param][index]
@@ -928,7 +997,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         if p not in self.mapping:
             self.mapping[p] = p_views = merge_group(group, p)
             for i, pv in enumerate(p_views):
-                self.mapping_inverse[pv] = (p, i)
+                self.mapping_inverse[_tensor_key(pv)] = (p, i)
 
     def split_p_and_g_in_group(
         self,
@@ -949,12 +1018,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 yield p, grad
                 continue
 
-
-
-
-            self.mapping[p] = p_views = merge_group(group, p)
-            for i, pv in enumerate(p_views):
-                self.mapping_inverse[pv] = (p, i)
+            self.mapping[p] = p_views = merge_group(group, p)
+            for i, pv in enumerate(p_views):
+                self.mapping_inverse[_tensor_key(pv)] = (p, i)
 
             vector = getattr(p, "vector", None)
             hessian_vector = getattr(p, "hessian_vector", None)
@@ -1199,17 +1265,53 @@ def _compilable_adam_(
 
 
 def adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+) -> List[Tensor]:
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_unscaled_adam_(
     exp_avg: List[Tensor],
     exp_avg_sq: List[Tensor],
     grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32 = list(map(promote, grad))
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+    g32 = torch._foreach_div(g32, denom)
+    exp_avg32 = _lerp(exp_avg, g32, beta1)
+    u32 = torch._foreach_mul(exp_avg32, denom)
+    copy_stochastic_list_(grad, u32)
+
+
+def unscaled_adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
     beta1: float,
     beta2: float,
     step: int,
     eps: float = 1e-8,
-):
+) -> List[Tensor]:
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
-
+    _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
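`_compilable_unscaled_adam_` is a LaProp-style variant: the gradient is divided by the debiased second-moment root before it enters the first-moment average, and the result is multiplied back by that denominator. A rough single-tensor reference, under the assumption that `beta_debias` is the usual rewrite β(1 - β^(t-1))/(1 - β^t) and that `eps_sqrt` clamps the root at `eps`; the names below are illustrative, not heavyball's API:

```python
import torch


def unscaled_adam_reference(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps=1e-8):
    # Debiased EMA coefficients (assumed to be what beta_debias computes).
    beta1_t = 1 - (1 - beta1) / (1 - beta1**step)
    beta2_t = 1 - (1 - beta2) / (1 - beta2**step)
    exp_avg_sq.mul_(beta2_t).add_(grad * grad, alpha=1 - beta2_t)
    denom = exp_avg_sq.sqrt().clamp(min=eps)
    exp_avg.mul_(beta1_t).add_(grad / denom, alpha=1 - beta1_t)
    return exp_avg * denom  # written back into grad by the compiled kernel
```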
@@ -1253,7 +1355,7 @@ def fused_adam_(
     caution: bool,
 ):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
+    beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
     _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
@@ -1332,7 +1434,7 @@ def fused_laprop_(
     eps: float = 1e-8,
 ):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-    beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+    beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
 
 
@@ -1351,7 +1453,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
 
 def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
     _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -1381,11 +1483,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
 
 
 @decorator_knowngood
-def stochastic_round_(ref: Tensor, source: Tensor):
-    if source
-
-
-
+def stochastic_round_(ref: Tensor, source: Tensor | None = None):
+    if source is not None:
+        if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+            return source
+        if ref.dtype != torch.bfloat16:
+            return source.to(ref.dtype)
+    else:
+        source = ref
+    source = source.float()
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
     result.add_(source.view(dtype=torch.int32))
     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
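`stochastic_round_` can now be called with a single tensor: it rounds to bfloat16 by adding 16 random low-order bits to the float32 bit pattern and then truncating them, so the expected value of the rounded number equals the input. A self-contained sketch of that bit trick (not the helper itself):

```python
import torch


def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    x = x.float()
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(dtype=torch.int32) + noise  # perturb the 16 bits that bf16 drops
    bits = bits & -65536                      # -65536 == 0xFFFF0000: keep sign, exponent, top mantissa
    return bits.view(dtype=torch.float32).bfloat16()


x = torch.full((100_000,), 1.0 + 2**-10)
print(stochastic_round_to_bf16(x).float().mean())  # ≈ 1 + 2**-10 on average, unlike deterministic rounding
```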
@@ -1908,7 +2014,9 @@ def ndim_tuple(Q: list[Tensor]) -> tuple:
     return tuple(q.ndim for q in Q)
 
 
-def psgd_calc_A_and_conjB(G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
+def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None):  # conjB ("V", "vector") == randn during hvp/whitening
+    if conjB is None:
+        conjB = torch.randn_like(G)
     exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim)  # calcA expr and cached precond expr are the same
     A = casted_einsum(exprA, *Q, G)
     solve = torch.compiler.disable(torch.linalg.solve_triangular)
@@ -1940,24 +2048,35 @@ def max_singular_value_exact(A, use_lobpcg: bool = False):
         eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
         return eigval[0].sqrt()
     else:
-        return torch.linalg.svd(A, driver="gesvdj")[1].max()  # == linalg.matrix_norm(A, ord=2)
-    except torch.linalg.LinAlgError:
-        return
+        return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype)  # == linalg.matrix_norm(A, ord=2)
+    except (torch.linalg.LinAlgError, RuntimeError):
+        return max_singular_value_power_iter(promote(A), iterations=2)
 
 
 @decorator_knowngood
-def max_singular_value_power_iter(
+def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
     """
     Rayleigh quotient of row with the largest norm + optional power iterations
     """
-    x_norm, max_idx =
-
-
-
-
-    x = A.
-
-
+    x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+    x_norm = promote(x_norm)
+
+    def _inner():
+        A = A_outer
+        x = A.index_select(0, max_idx).flatten().contiguous()
+        A = stochastic_round_(A / x_norm)
+        x = x / x_norm
+
+        def _mv(x):
+            return promote(A.T.mv(A.mv(stochastic_round_(x))))
+
+        for _ in range(iterations):
+            # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
+            x = F.normalize(_mv(x), dim=0)
+        out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+        return out.squeeze().clone()
+
+    return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
 
 
 @decorator_knowngood
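The rewritten `max_singular_value_power_iter` warm-starts power iteration with the largest-norm row and applies AᵀA as two matrix-vector products instead of forming the Gram matrix. A condensed float32 sketch of the same estimator, without the bf16 rounding or the `cond` guard for the all-zero case:

```python
import torch


def spectral_norm_power_iter(a: torch.Tensor, iterations: int = 5) -> torch.Tensor:
    row_norms = a.norm(dim=1)
    x = a[row_norms.argmax()].clone()  # warm start: the row with the largest norm
    x = x / x.norm()
    for _ in range(iterations):
        x = a.T.mv(a.mv(x))            # one step of power iteration on A^T A
        x = x / x.norm()
    return (x @ a.T.mv(a.mv(x))).sqrt()  # Rayleigh quotient -> sigma_max estimate


a = torch.randn(256, 128)
print(spectral_norm_power_iter(a), torch.linalg.matrix_norm(a, ord=2))
```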
@@ -1974,33 +2093,81 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
     return sketch_norm * max_abs
 
 
+def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+    if A.ndim <= 2:
+        return max_singular_value(A, max_svd, use_cholesky, power_iter)
+
+    base = einsum_base[: A.ndim]
+    A16 = stochastic_round_(A)
+    squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
+    svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
+    svds = torch.stack(svds)
+    return svds.max().sqrt().to(A.dtype)  # sqrt because we took the SVD of a squared matrix
+
+
 @decorator_knowngood
-def max_singular_value(
-    A
-)
+def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+    if A.ndim < 2:
+        return A.abs().max()
+    if A.ndim > 2:
+        raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
     if min(A.shape) <= max_svd:
         return max_singular_value_exact(A)  # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
     if use_cholesky or power_iter < 0:
-        return max_singular_value_cholesky(A
+        return max_singular_value_cholesky(A)
     return max_singular_value_power_iter(A, None, iterations=power_iter)
 
 
 @decorator_knowngood
-def
-
-
-)
-    out =
-
-
-
-
-
-
-
-
-
-
+def clamped_max_singular_value(
+    A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
+) -> Tensor:
+    norm = A.norm()  # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
+    out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
+    return out.clamp(min=min)
+
+
+@decorator_knowngood
+def min_singular_value(
+    A: Tensor,
+    power_iter: int = 5,
+    safety: float = 1.05,
+    max_svd: int = 32,
+):
+    if A.ndim < 2:
+        return A.abs().min()
+
+    n = A.size(0)
+    if n <= max_svd:
+        try:
+            eigs = torch.linalg.eigvalsh(promote(A))
+            return eigs.min().to(A.dtype)
+        except torch.linalg.LinAlgError:
+            pass
+
+    lambda_max_hat = max_singular_value(A, power_iter=power_iter)
+    lambda_upper = lambda_max_hat * safety
+
+    row_norms = A.norm(dim=1)
+    norm, idx = row_norms.min(dim=0)
+    v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))
+
+    v = v / promote(v.norm())
+    for _ in range(power_iter):
+        v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
+        v = v / promote(v.norm())
+    mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))
+
+    lambda_min_hat = lambda_upper - mu_hat
+
+    def _approx():
+        mu = A.trace() / n
+        sigma_square = A.square().sum() / n - mu**2
+        return mu - (sigma_square / (n - 1)).sqrt()
+
+    return cond(
+        (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
+    ).squeeze()
 
 
 @decorator_knowngood
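`min_singular_value` estimates the smallest eigenvalue of a symmetric PSD matrix by power iteration on the shifted matrix λ_upper·I - A, whose dominant eigenvalue is λ_upper - λ_min, and falls back to a trace-based guess if that estimate is non-finite or non-positive. A compact float32 sketch of the shifted iteration under those same assumptions:

```python
import torch


def min_eig_shifted_power_iter(a: torch.Tensor, power_iter: int = 20, safety: float = 1.05) -> torch.Tensor:
    # a is assumed symmetric positive semi-definite.
    lam_upper = torch.linalg.matrix_norm(a, ord=2) * safety  # upper bound on the largest eigenvalue
    v = torch.randn(a.size(0))
    v = v / v.norm()
    for _ in range(power_iter):
        v = lam_upper * v - a.mv(v)  # power iteration on (lam_upper * I - a)
        v = v / v.norm()
    mu = v @ (lam_upper * v - a.mv(v))  # ≈ lam_upper - lambda_min
    return lam_upper - mu


a = torch.randn(64, 32)
gram = a.T @ a
print(min_eig_shifted_power_iter(gram), torch.linalg.eigvalsh(gram).min())
```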
@@ -2027,6 +2194,107 @@ def calcG_expr(q_dim, g_dim):
     return exprs
 
 
+def eye_like(x: Tensor):
+    if x.ndim < 2:
+        return torch.ones_like(x)
+    assert x.ndim == 2
+    assert x.size(0) == x.size(1)
+    return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
+
+
+@decorator_knowngood
+def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
+    """
+    Idea:
+        G should be zeroth power. So, all Qs together should approximate the G's inverse.
+        Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R
+        Optimize LGR being a zeroth power using `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T + (LGR) , I )`,
+        then backprop to L/R jointly.
+        This function computes the gradients for L/R, with an outer optimizer layer handling the rest.
+
+    `psgd_precond_grad` computes LGR for the general (n-dimensional) case
+    `exprG` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case
+    Args:
+        G: Gradient that should be orthogonalized
+        Q: List of preconditioner tensors.
+
+    Returns:
+        - List of gradients with respect to Q (d_Q).
+    """
+    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+
+    G16 = stochastic_round_(G)
+    Q16 = [stochastic_round_(q) for q in Q]
+    P = psgd_precond_grad(G16, Q16)  # Q₀GQ₁
+
+    d_P = torch.zeros_like(G)
+    base = einsum_base[: G.ndim]
+    for i, exprG in enumerate(exprGs):
+        pp = compiled_einsum(exprG, P, P)
+        error = pp - eye_like(pp)
+        dim = einsum_base[i]
+        if pp.ndim == 2:
+            new = dim.upper()
+            prec = f"{new}{dim}"
+        else:
+            new = dim
+            prec = dim
+        d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)
+
+    d_P = stochastic_round_(d_P)  # accumulate in fp32 and round at the end
+    grads = []
+    for i, exprG in enumerate(exprGs):
+        new_q = Q16[:]
+        new_q[i] = eye_like(new_q[i])
+        pq = psgd_precond_grad(G16, new_q)
+        grad = compiled_einsum(exprG, pq, d_P)
+        if grad.ndim == 2:
+            grad = (grad + grad.T) / 2
+        grads.append(grad)
+
+    return grads, P.to(G.dtype)
+
+
+def _inverse_initial_guess(gg):
+    n = gg.shape[0]
+
+    sigma_max = promote(gg.norm())
+
+    trace_gg = promote(torch.trace(gg))
+    sigma_min_approx = trace_gg / (n * sigma_max)
+
+    return sigma_max, sigma_min_approx
+
+
+@decorator_knowngood
+def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
+    k = torch.arange(degree, dtype=torch.float64, device=device)
+    rotation = (2 * k + 1) * math.pi / (2 * degree)
+    f = (rotation.cos() + 1 + eps) ** -0.5
+    rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()
+    coeff0 = f.sum() / degree
+    coeffs = f @ rotation * 2 / degree
+    return coeff0.float(), coeffs.float()
+
+
+@decorator_knowngood
+def _psgd_default_preconditioner_grad(
+    terms: List[Tuple[Tensor, Tensor]],
+    Q: List[Tensor],
+) -> List[Tensor]:
+    out = []
+    for q, (x, y) in zip(Q, terms):
+        x = promote(x)
+        y = promote(y)
+        update = x - y
+        if q.ndim < 2:
+            update = q * update
+        else:
+            update = (q @ update).triu()
+        out.append(update)
+    return out
+
+
 @decorator
 def psgd_update_precond(
     G: Tensor,
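`_chebychef_coeff` computes discrete Chebyshev coefficients of f(cos θ) = (cos θ + 1 + ε)^(-1/2) from its values at the Chebyshev nodes: c₀ = (1/N) Σ f_k and c_j = (2/N) Σ f_k cos(j θ_k). With all N coefficients, the resulting series interpolates f exactly at those nodes. A small standalone check of that property (the function names here are illustrative):

```python
import math
import torch


def chebyshev_coeff_inv_sqrt(degree: int, eps: float = 1e-8):
    # Coefficients of the Chebyshev interpolant of f(x) = (x + 1 + eps)**-0.5 on [-1, 1],
    # sampled at the nodes x_k = cos((2k + 1) * pi / (2 * degree)).
    k = torch.arange(degree, dtype=torch.float64)
    theta = (2 * k + 1) * math.pi / (2 * degree)
    f = (theta.cos() + 1 + eps) ** -0.5
    c0 = f.sum() / degree
    cj = f @ (theta.view(-1, 1) * k[1:].view(1, -1)).cos() * 2 / degree
    return c0, cj


def chebyshev_eval(c0, cj, x):
    # Evaluate c0 + sum_j cj[j-1] * T_j(x), using T_j(cos t) = cos(j t).
    t = math.acos(x)
    return float(c0) + sum(float(c) * math.cos((j + 1) * t) for j, c in enumerate(cj))


c0, cj = chebyshev_coeff_inv_sqrt(degree=16)
x = math.cos((2 * 3 + 1) * math.pi / (2 * 16))  # one of the sampling nodes
print(chebyshev_eval(c0, cj, x), (x + 1 + 1e-8) ** -0.5)  # agree: the series interpolates f at the nodes
```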
@@ -2056,6 +2324,102 @@ def psgd_update_precond(
     return None
 
 
+@decorator_knowngood
+def bf16_matmul(x: Tensor, y: Tensor):
+    return (promote(x) @ promote(y)).to(x.dtype)
+
+
+def if_iscompiling(fn):
+    base = getattr(torch, fn.__name__, None)
+
+    def _fn(x):
+        if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
+            return base(x)
+        return fn(x)
+
+    return _fn
+
+
+@if_iscompiling
+def while_loop(cond, body, state):
+    """
+    dispatches to torch.while_loop if we're compiling. otherwise, falls back to a naive + slow baseline
+    useful for debugging
+    """
+    while cond(*state).item():
+        state = body(*state)
+    return state
+
+
+@if_iscompiling
+def cond(cond, true_fn, false_fn):
+    """
+    dispatches to torch.cond if we're compiling. otherwise, falls back to a naive + slow baseline
+    useful for debugging
+    """
+
+    if cond.item():
+        return true_fn()
+    return false_fn()
+
+
+@decorator_knowngood
+def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
+    """
+    Return w such that H = I - 2 w w^T is orthogonal and H e1 = v (v unit).
+    Applying from the right: G @ H = G - 2 (G @ w) w^T.
+    If v is (numerically) e1, returns w=0 and H=I.
+    """
+    v = v / v.norm().clamp(min=eps)
+    e1 = torch.zeros_like(v)
+    e1[0] = 1.0
+    w = e1 - v
+    return w / w.norm().clamp(min=eps)
+
+
+@decorator_knowngood
+def eigvecs_product_rank1(
+    G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compute Y = G @ V where V is an eigenvector matrix for P = λ I + σ v v^T,
+    using the Householder reflector with first column v. Never materializes V.
+
+    Args:
+        G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
+        v: shape (d,) — current unit direction (top eigenvector of P).
+        w: optional Householder vector w; pass to reuse across calls.
+
+    Returns:
+        (Y, w) where:
+            Y has shape (..., d) and equals G @ eigenvectors(P),
+            w is the Householder vector you can cache & reuse.
+    """
+    if w is None:
+        w = _householder_vec_e1_to_v(v, eps)
+    Y = G - 2.0 * compiled_einsum("...i,i,j->...j", G, w, w)
+    return Y, w
+
+
+@decorator_knowngood
+def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
+    """
+    One Oja step to track the top eigendirection of the gradient covariance.
+    v <- v + lr * ((g^T v) g - (g^T v)^2 v); then renormalize.
+    """
+    gv = g @ v
+    v = v + lr * (gv * g - (gv * gv) * v)
+    return v / v.norm().clamp(min=eps)
+
+
+def cond_n(cond_val: Tensor, *fns):
+    fns = list(fns)
+    fn = fns.pop(0)
+    if not fns:
+        return fn
+    return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
+
+
 @decorator_knowngood
 def _psgd_precond_update_(
     matmuled: List[Optional[Tensor]],
|
|
2074
2438
|
if update.ndim < 2:
|
2075
2439
|
lb = update.norm(float("inf"))
|
2076
2440
|
else:
|
2077
|
-
lb = max_singular_value(update,
|
2441
|
+
lb = max_singular_value(update, power_iter=power_iter)
|
2078
2442
|
update = promote(update)
|
2079
2443
|
if store_triu_as_line:
|
2080
2444
|
update = triu_to_line([update])[0][1]
|
@@ -2146,70 +2510,83 @@ def inverse_free_psgd_update_precond(
 
 
 @decorator_knowngood
-def
-
-    x =
-    norm =
-
-
-
+def _clip(x, norm, clip_at, eps=1e-8):
+    x32 = promote(x)
+    # (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
+    norm = clip_at / norm.clamp(min=max(clip_at, eps))
+    x32 = x32 * norm
+    copy_stochastic_(x, x32)
+
+
+@decorator_knowngood
+def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
+    for x in xs:
+        _clip(x, promote(x).norm(), clip_at, eps)
 
 
 def l2_normalization_(x, clip_at: float = 1e-8):
     x = list_guard(x)
-
+    _compilable_l2_clip_(x, clip_at)
+    return x
 
 
 def l2_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    _compilable_l2_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
-def _compilable_rmsnorm_clip_(
-    x
-
-    norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm, clip_at)
-    return torch._foreach_div(x, norm)
+def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
+    for x in xs:
+        _clip(x, promote(x).square().mean().sqrt(), clip_at, eps)
 
 
 def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-
-
-
-@decorator_knowngood
-def _compilable_global_rmsnorm_clip_(x, clip_at):
-    x = list(map(promote, x))
-    norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
-    norm = norm**0.5
-    norm = norm.clamp(min=clip_at)
-    return torch._foreach_div(x, norm)
+    _compilable_rmsnorm_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
-def
-
-
-
-
-
+def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
+    norm = 0
+    numel = sum([i.numel() for i in x])
+    for i in x:
+        norm += promote(i).square().sum()
+    norm = (norm / numel) ** 0.5
+    scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+    stochastic_multiply_(x, scalar)
 
 
 def global_rmsnorm_clip(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    clip_at = scalar_guard(clip_at, x[0])
+    _compilable_global_rmsnorm_clip_(x, clip_at)
+    return x
+
+
+@decorator_knowngood
+def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
+    norm = 0
+    for i in x:
+        norm += promote(i).square().sum()
+    norm = norm**0.5
+    scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+    stochastic_multiply_(x, scalar)
 
 
 def global_l2norm_clip(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    clip_at = scalar_guard(clip_at, x[0])
+    _compilable_global_l2norm_clip_(x, clip_at)
+    return x
 
 
 def rmsnorm_normalize_(x, clip_at: float = 1e-6):
     x = list_guard(x)
-
+    _compilable_rmsnorm_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
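The new `_clip` helper expresses norm clipping as a single multiplicative factor: scaling by clip_at / max(norm, clip_at, eps) leaves tensors already below the threshold unchanged and rescales the rest exactly onto it, matching the `(x / norm.clamp(min=eps)).clamp(max=1)` identity quoted in the comment. A scalar sketch of that factor:

```python
import torch


def clip_factor(norm: torch.Tensor, clip_at: float, eps: float = 1e-8) -> torch.Tensor:
    # Equivalent to (clip_at / norm.clamp(min=eps)).clamp(max=1), written as one division.
    return clip_at / norm.clamp(min=max(clip_at, eps))


x = torch.randn(1000)
clip_at = 0.5
clipped = x * clip_factor(x.norm(), clip_at)
print(x.norm(), clipped.norm())  # the second value is at most clip_at
```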
@@ -2284,17 +2661,6 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
     _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
 
 
-@decorator_knowngood
-def _compilable_weight_decay_to_init_(p, init, weight_decay):
-    _lerp(p, promote(init), 1 - weight_decay)
-
-
-def weight_decay_to_init_(p, init, weight_decay):
-    p, init = list_guard(p, init)
-    weight_decay = scalar_guard(weight_decay, p[0])
-    _compilable_weight_decay_to_ema_(p, init, weight_decay)
-
-
 @decorator_knowngood
 def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
     ema32 = _lerp(ema, p, ema_decay)
@@ -2416,7 +2782,7 @@ def precond_grad_cached_(
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
-    expr = cached_precond_grad_expr(ndim_tuple(cached_q),
+    expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
     new = compiled_einsum(expr, *args)
     if cast:
         return new.to(ea.dtype)
@@ -2433,7 +2799,7 @@ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, c
 
 
 def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
-    lr = scalar_guard(lr, param[0])
+    lr, decay = scalar_guard(lr, decay, param[0])
     _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
 
 
@@ -2502,7 +2868,7 @@ def fused_psgd_precond_grad(
     store_triu_as_line: bool = False,
     symmetric_output: bool = False,
 ):
-    lr = scalar_guard(lr, param[0])
+    lr, decay = scalar_guard(lr, decay, param[0])
     _compilable_fused_psgd_precond_grad(
         ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
     )