heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +168 -29
- heavyball/chainable.py +165 -63
- heavyball/helpers.py +5 -1
- heavyball/utils.py +507 -124
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/METADATA +19 -7
- heavyball-2.1.1.dist-info/RECORD +9 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/WHEEL +1 -1
- heavyball-2.0.0.dev0.dist-info/RECORD +0 -9
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import collections
 import contextlib
+import enum
 import functools
 import gc
 import inspect
@@ -16,13 +17,29 @@ import torch
 from torch import Tensor
 from torch._dynamo.exc import TorchDynamoException
 from torch.backends import cudnn, opt_einsum
+from torch.nn import functional as F
 from torch.utils._pytree import tree_map
 
+
+class ZerothPowerMode(enum.Enum):
+    newtonschulz = "newtonschulz"
+    legacy_newtonschulz = "legacy_newtonschulz"
+    qr = "qr"
+    svd = "svd"
+    legacy_svd = "legacy_svd"
+
+
+class OrthoScaleMode(enum.Enum):
+    none = "none"
+    scale = "scale"
+    graft = "graft"
+
+
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = "newtonschulz"
-precise_zeroth_power_mode = "qr"
+precise_zeroth_power_mode = "qr"
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 _cudnn_double_backward_pattern = re.compile(
     r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -240,14 +257,14 @@ def eps_sqrt(item, eps):
 
 @decorator_knowngood
 def _compilable_exp_avg_sq_(
-    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[
+    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
 ):
     g32 = promote(grad)
     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
 
     denom = [eps_sqrt(d, eps) for d in s32]
 
-    if out[0] is None:
+    if out is None or out[0] is None:
         return denom
 
     copy_stochastic_list_(out, denom)
@@ -316,8 +333,8 @@ def adaptive_gradient_clipping_(
 def is_compiling():
     try:
         return torch.compiler.is_compiling()
-    except TorchDynamoException:
-        return
+    except (TorchDynamoException, AttributeError):
+        return False
 
 
 def set_(dst: Tensor, src: Tensor):
@@ -366,12 +383,45 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
 )
 
 
-@
+@decorator_knowngood
 def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
+    assert (
+        G.ndim >= 2
+    )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    assert steps == 5
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps)  # ensure top singular value <= 1
+    # Perform the NS iterations
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
+        A = X @ X.mT
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X.to(G.dtype)
+
+
+@decorator_knowngood
+def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-
-    X
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+
+    # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
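For orientation: the new batched `zeropower_via_newtonschulz5` is the Muon-style quintic Newton-Schulz iteration that pushes all singular values of G toward 1, i.e. it approximates the semi-orthogonal factor U Vᵀ of the SVD. A minimal fp32 sketch of the same iteration outside heavyball (no `stochastic_round_` / `stochastic_divide_with_eps_`, which are only precision tricks in the real routine):

import torch

def newtonschulz5_sketch(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    X = G.clone().float()
    if X.size(-2) > X.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # top singular value <= 1
    for a, b, c in [(4.0848, -6.8946, 2.9270), (3.9505, -6.3029, 2.6377),
                    (3.7418, -5.5913, 2.3037), (2.8769, -3.1427, 1.2046),
                    (2.8366, -3.0525, 1.2012)]:
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

G = torch.randn(64, 32)
print(torch.linalg.svdvals(newtonschulz5_sketch(G)))  # all near 1; the iteration is deliberately coarse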
@@ -435,21 +485,30 @@ def _compilable_grafting(magnitude, direction):
 
 
 @decorator_knowngood
-def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
-    if mode
+def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
+    if not isinstance(mode, ZerothPowerMode):
+        mode = ZerothPowerMode(mode)
+    if not isinstance(scale_mode, ZerothPowerMode):
+        scale_mode = OrthoScaleMode(scale_mode)
+    if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode ==
+    elif mode == ZerothPowerMode.legacy_newtonschulz:
+        y = legacy_zeropower_via_newtonschulz5(x, 5)
+    elif mode == ZerothPowerMode.qr:
         y = torch.linalg.qr(promote(x)).Q
-    elif mode ==
-        u, _s,
-        y = u @
+    elif mode == ZerothPowerMode.svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt
+    elif mode == ZerothPowerMode.legacy_svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
-    if scale_mode ==
+    if scale_mode == OrthoScaleMode.none:
         pass
-    elif scale_mode ==
-        y *= max(1, x.size(
-    elif scale_mode ==
+    elif scale_mode == OrthoScaleMode.scale:
+        y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
+    elif scale_mode == OrthoScaleMode.graft:
         y = _compilable_grafting(x, y)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
@@ -556,10 +615,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
         except torch.OutOfMemoryError:
             if m.device.type == "cpu":
                 raise
-
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(m.device)
+                clean()
+            m = m.cpu()
+        except RuntimeError as e:
+            if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
+                torch.cuda.synchronize(m.device)
+                clean()
                 m = m.cpu()
-
-            if m.dtype != torch.double:
+            elif m.dtype != torch.double:
                 m = m.double()
             elif eps < max_eps:
                 eps = eps ** (2 / 3)
@@ -658,7 +723,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
         copy_stochastic_(x_, x32 + y32 * alpha)
 
 
-def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
+def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
     x, y = broadcastable_list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
@@ -672,7 +737,9 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
         copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
 
 
-def stochastic_add_divide_(
+def stochastic_add_divide_(
+    x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
+):
     x, y = broadcastable_list_guard(x, y)
     alpha, divisor = scalar_guard(alpha, divisor, x[0])
     _compilable_stochastic_add_divide_(x, y, alpha, divisor)
@@ -686,11 +753,25 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
         copy_stochastic_(x_, x32 * y32)
 
 
-def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
     x, y = broadcastable_list_guard(x, y)
     _compilable_stochastic_multiply_(x, y)
 
 
+@decorator_knowngood
+def _compilable_stochastic_divide_with_eps_(x: List[Tensor], y: List[Tensor], eps: Tensor):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 / (y32 + eps))
+
+
+def stochastic_divide_with_eps_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, eps: float):
+    x, y = broadcastable_list_guard(x, y)
+    eps = scalar_guard(eps, y[0])
+    _compilable_stochastic_divide_with_eps_(x, y, eps)
+
+
 @decorator
 def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     """
@@ -832,6 +913,10 @@ class ExactHVPFailed(ValueError):
 use_default = object()
 
 
+def _tensor_key(x: Tensor):
+    return x.data_ptr(), x.numel(), x.dtype, x.device
+
+
 class StatefulOptimizer(torch.optim.Optimizer):
     """
     finite_differences saves memory, but needs more compute. (Alternative is true HVP)
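The new `_tensor_key` lets `mapping_inverse` key parameter views by storage identity rather than by tensor object, so equivalent views over the same storage resolve to the same `(param, index)` slot, and `state_` (below) can rebuild the mapping lazily when a key is missing. A small standalone illustration of the keying idea (not heavyball code):

import torch

def tensor_key(x: torch.Tensor):
    return x.data_ptr(), x.numel(), x.dtype, x.device

p = torch.zeros(4)
view_a = p.view(4)
view_b = p.view(4)           # a different Python object over the same storage
print(view_a is view_b)      # False: object identity would give two dict slots
print(tensor_key(view_a) == tensor_key(view_b))  # True: the storage-based key is stable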
@@ -874,7 +959,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
         self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
-        self._init_mapping()
 
     def _store_stats(self, state_dict: dict[str, any]):
         state_dict["heavyball"] = {
@@ -905,7 +989,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def state_(self, arg: Tensor, fail: bool = True):
         if not fail and arg not in self.mapping:
             return {}
-
+        if _tensor_key(arg) not in self.mapping_inverse:
+            self._init_mapping()
+        state_param, index = self.mapping_inverse[_tensor_key(arg)]
         if state_param not in self.state:
             self.state[state_param] = collections.defaultdict(dict)
         return self.state[state_param][index]
@@ -928,7 +1014,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             if p not in self.mapping:
                 self.mapping[p] = p_views = merge_group(group, p)
                 for i, pv in enumerate(p_views):
-                    self.mapping_inverse[pv] = (p, i)
+                    self.mapping_inverse[_tensor_key(pv)] = (p, i)
 
     def split_p_and_g_in_group(
         self,
@@ -949,12 +1035,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 yield p, grad
                 continue
 
-
-
-
-            self.mapping[p] = p_views = merge_group(group, p)
-            for i, pv in enumerate(p_views):
-                self.mapping_inverse[pv] = (p, i)
+            self.mapping[p] = p_views = merge_group(group, p)
+            for i, pv in enumerate(p_views):
+                self.mapping_inverse[_tensor_key(pv)] = (p, i)
 
             vector = getattr(p, "vector", None)
             hessian_vector = getattr(p, "hessian_vector", None)
@@ -1199,17 +1282,53 @@ def _compilable_adam_(
 
 
 def adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+) -> List[Tensor]:
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_unscaled_adam_(
     exp_avg: List[Tensor],
     exp_avg_sq: List[Tensor],
     grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32 = list(map(promote, grad))
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+    g32 = torch._foreach_div(g32, denom)
+    exp_avg32 = _lerp(exp_avg, g32, beta1)
+    u32 = torch._foreach_mul(exp_avg32, denom)
+    copy_stochastic_list_(grad, u32)
+
+
+def unscaled_adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
     beta1: float,
     beta2: float,
     step: int,
     eps: float = 1e-8,
-):
+) -> List[Tensor]:
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
-
+    _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
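The new `unscaled_adam_` path is LaProp-like: the second moment normalizes the gradient before the momentum average, and the result is then multiplied back by the same denominator, so only the direction is smoothed. A rough single-tensor fp32 reference of that idea (explicit bias correction instead of heavyball's debiased-beta lerp; details may differ from the routine above):

import torch

def unscaled_adam_sketch(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps=1e-8):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2**step)).sqrt() + eps
    exp_avg.mul_(beta1).add_(grad / denom, alpha=1 - beta1)  # momentum of the normalized gradient
    return exp_avg / (1 - beta1**step) * denom               # re-scale by denom: the "unscaled" update

g = torch.randn(10)
m, v = torch.zeros(10), torch.zeros(10)
print(unscaled_adam_sketch(m, v, g, 0.9, 0.999, step=1))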
@@ -1253,7 +1372,7 @@ def fused_adam_(
     caution: bool,
 ):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
+    beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
     _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
@@ -1332,7 +1451,7 @@ def fused_laprop_(
     eps: float = 1e-8,
 ):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-    beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+    beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
 
 
@@ -1351,7 +1470,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
 
 def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
     _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -1381,11 +1500,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
 
 
 @decorator_knowngood
-def stochastic_round_(ref: Tensor, source: Tensor):
-    if source
-
-
-
+def stochastic_round_(ref: Tensor, source: Tensor | None = None):
+    if source is not None:
+        if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+            return source
+        if ref.dtype != torch.bfloat16:
+            return source.to(ref.dtype)
+    else:
+        source = ref
+    source = source.float()
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
     result.add_(source.view(dtype=torch.int32))
     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
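`stochastic_round_` now also accepts a single argument and rounds fp32 to bf16 by adding random bits below the bf16 mantissa before truncating (the bitwise trick in the context lines above). A standalone sketch of why that truncation is an unbiased rounding:

import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    x = x.float()
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(dtype=torch.int32) + noise
    bits &= -65536                       # keep the top 16 bits, i.e. a valid bf16 pattern
    return bits.view(dtype=torch.float32).bfloat16()

x = torch.full((100_000,), 1.0 + 1 / 512)            # halfway between two bf16 values
print(stochastic_round_to_bf16(x).float().mean())    # ~1.00195 in expectation, unlike a deterministic cast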
@@ -1908,7 +2031,9 @@ def ndim_tuple(Q: list[Tensor]) -> tuple:
     return tuple(q.ndim for q in Q)
 
 
-def psgd_calc_A_and_conjB(G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
+def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None):  # conjB ("V", "vector") == randn during hvp/whitening
+    if conjB is None:
+        conjB = torch.randn_like(G)
     exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim)  # calcA expr and cached precond expr are the same
     A = casted_einsum(exprA, *Q, G)
     solve = torch.compiler.disable(torch.linalg.solve_triangular)
@@ -1940,24 +2065,35 @@ def max_singular_value_exact(A, use_lobpcg: bool = False):
             eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
             return eigval[0].sqrt()
         else:
-            return torch.linalg.svd(A, driver="gesvdj")[1].max()  # == linalg.matrix_norm(A, ord=2)
-    except torch.linalg.LinAlgError:
-        return
+            return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype)  # == linalg.matrix_norm(A, ord=2)
+    except (torch.linalg.LinAlgError, RuntimeError):
+        return max_singular_value_power_iter(promote(A), iterations=2)
 
 
 @decorator_knowngood
-def max_singular_value_power_iter(
+def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
     """
     Rayleigh quotient of row with the largest norm + optional power iterations
     """
-    x_norm, max_idx =
-
-
-
-
-    x = A.
-
-
+    x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+    x_norm = promote(x_norm)
+
+    def _inner():
+        A = A_outer
+        x = A.index_select(0, max_idx).flatten().contiguous()
+        A = stochastic_round_(A / x_norm)
+        x = x / x_norm
+
+        def _mv(x):
+            return promote(A.T.mv(A.mv(stochastic_round_(x))))
+
+        for _ in range(iterations):
+            # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
+            x = F.normalize(_mv(x), dim=0)
+        out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+        return out.squeeze().clone()
+
+    return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
 
 
 @decorator_knowngood
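The rewritten `max_singular_value_power_iter` estimates the spectral norm with a power iteration on AᵀA, seeded from the largest-norm row and run in bf16 via `stochastic_round_`. Stripped of the precision and `torch.compile` plumbing, the underlying estimate is the usual power method; a plain sketch, not the library routine:

import torch

def spectral_norm_power_iter(A: torch.Tensor, iterations: int = 5) -> torch.Tensor:
    A = A.float()
    x = A[A.norm(dim=1).argmax()]            # warm start: row with the largest norm
    x = x / x.norm()
    for _ in range(iterations):
        x = torch.nn.functional.normalize(A.T @ (A @ x), dim=0)
    return torch.sqrt(x @ (A.T @ (A @ x)))   # Rayleigh quotient of A^T A -> sigma_max

A = torch.randn(128, 64)
print(spectral_norm_power_iter(A), torch.linalg.matrix_norm(A, ord=2))  # close, up to power-iteration error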
@@ -1974,33 +2110,81 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
     return sketch_norm * max_abs
 
 
+def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+    if A.ndim <= 2:
+        return max_singular_value(A, max_svd, use_cholesky, power_iter)
+
+    base = einsum_base[: A.ndim]
+    A16 = stochastic_round_(A)
+    squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
+    svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
+    svds = torch.stack(svds)
+    return svds.max().sqrt().to(A.dtype)  # sqrt because we took the SVD of a squared matrix
+
+
 @decorator_knowngood
-def max_singular_value(
-    A
-)
+def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
+    if A.ndim < 2:
+        return A.abs().max()
+    if A.ndim > 2:
+        raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
     if min(A.shape) <= max_svd:
         return max_singular_value_exact(A)  # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
     if use_cholesky or power_iter < 0:
-        return max_singular_value_cholesky(A
+        return max_singular_value_cholesky(A)
     return max_singular_value_power_iter(A, None, iterations=power_iter)
 
 
 @decorator_knowngood
-def
-
-
-)
-    out =
-
-
-
-
-
-
-
-
-
-
+def clamped_max_singular_value(
+    A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
+) -> Tensor:
+    norm = A.norm()  # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
+    out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
+    return out.clamp(min=min)
+
+
+@decorator_knowngood
+def min_singular_value(
+    A: Tensor,
+    power_iter: int = 5,
+    safety: float = 1.05,
+    max_svd: int = 32,
+):
+    if A.ndim < 2:
+        return A.abs().min()
+
+    n = A.size(0)
+    if n <= max_svd:
+        try:
+            eigs = torch.linalg.eigvalsh(promote(A))
+            return eigs.min().to(A.dtype)
+        except torch.linalg.LinAlgError:
+            pass
+
+    lambda_max_hat = max_singular_value(A, power_iter=power_iter)
+    lambda_upper = lambda_max_hat * safety
+
+    row_norms = A.norm(dim=1)
+    norm, idx = row_norms.min(dim=0)
+    v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))
+
+    v = v / promote(v.norm())
+    for _ in range(power_iter):
+        v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
+        v = v / promote(v.norm())
+    mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))
+
+    lambda_min_hat = lambda_upper - mu_hat
+
+    def _approx():
+        mu = A.trace() / n
+        sigma_square = A.square().sum() / n - mu**2
+        return mu - (sigma_square / (n - 1)).sqrt()
+
+    return cond(
+        (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
+    ).squeeze()
 
 
 @decorator_knowngood
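`min_singular_value` uses the standard spectral-shift trick for a symmetric positive (semi)definite matrix: power-iterate B = λ_upper·I − A, whose top eigenvalue is λ_upper − λ_min(A), then subtract. A plain verification sketch on a small SPD matrix (assumes symmetric input, as the `eigvalsh` fast path above does):

import torch

def min_eig_via_shift(A: torch.Tensor, iters: int = 50) -> torch.Tensor:
    lam_upper = 1.05 * torch.linalg.matrix_norm(A, ord=2)  # safety-scaled upper bound
    v = torch.randn(A.size(0))
    for _ in range(iters):
        v = lam_upper * v - A @ v          # multiply by B = lam_upper * I - A
        v = v / v.norm()
    mu = v @ (lam_upper * v - A @ v)       # top eigenvalue of B
    return lam_upper - mu                  # = smallest eigenvalue of A

M = torch.randn(32, 16)
A = M.T @ M + 0.1 * torch.eye(16)          # SPD test matrix
print(min_eig_via_shift(A), torch.linalg.eigvalsh(A).min())  # approximately equal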
@@ -2027,6 +2211,107 @@ def calcG_expr(q_dim, g_dim):
     return exprs
 
 
+def eye_like(x: Tensor):
+    if x.ndim < 2:
+        return torch.ones_like(x)
+    assert x.ndim == 2
+    assert x.size(0) == x.size(1)
+    return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
+
+
+@decorator_knowngood
+def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
+    """
+    Idea:
+        G should be zeroth power. So, all Qs together should approximate the G's inverse.
+        Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R
+        Optimize LGR being a zeroth power using `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T + (LGR) , I )`,
+        then backprop to L/R jointly.
+        This function computes the gradients for L/R, with an outer optimizer layer handling the rest.
+
+        `psgd_precond_grad` computes LGR for the general (n-dimensional) case
+        `exprG` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case
+    Args:
+        G: Gradient that should be orthogonalized
+        Q: List of preconditioner tensors.
+
+    Returns:
+        - List of gradients with respect to Q (d_Q).
+    """
+    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+
+    G16 = stochastic_round_(G)
+    Q16 = [stochastic_round_(q) for q in Q]
+    P = psgd_precond_grad(G16, Q16)  # Q₀GQ₁
+
+    d_P = torch.zeros_like(G)
+    base = einsum_base[: G.ndim]
+    for i, exprG in enumerate(exprGs):
+        pp = compiled_einsum(exprG, P, P)
+        error = pp - eye_like(pp)
+        dim = einsum_base[i]
+        if pp.ndim == 2:
+            new = dim.upper()
+            prec = f"{new}{dim}"
+        else:
+            new = dim
+            prec = dim
+        d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)
+
+    d_P = stochastic_round_(d_P)  # accumulate in fp32 and round at the end
+    grads = []
+    for i, exprG in enumerate(exprGs):
+        new_q = Q16[:]
+        new_q[i] = eye_like(new_q[i])
+        pq = psgd_precond_grad(G16, new_q)
+        grad = compiled_einsum(exprG, pq, d_P)
+        if grad.ndim == 2:
+            grad = (grad + grad.T) / 2
+        grads.append(grad)
+
+    return grads, P.to(G.dtype)
+
+
+def _inverse_initial_guess(gg):
+    n = gg.shape[0]
+
+    sigma_max = promote(gg.norm())
+
+    trace_gg = promote(torch.trace(gg))
+    sigma_min_approx = trace_gg / (n * sigma_max)
+
+    return sigma_max, sigma_min_approx
+
+
+@decorator_knowngood
+def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
+    k = torch.arange(degree, dtype=torch.float64, device=device)
+    rotation = (2 * k + 1) * math.pi / (2 * degree)
+    f = (rotation.cos() + 1 + eps) ** -0.5
+    rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()
+    coeff0 = f.sum() / degree
+    coeffs = f @ rotation * 2 / degree
+    return coeff0.float(), coeffs.float()
+
+
+@decorator_knowngood
+def _psgd_default_preconditioner_grad(
+    terms: List[Tuple[Tensor, Tensor]],
+    Q: List[Tensor],
+) -> List[Tensor]:
+    out = []
+    for q, (x, y) in zip(Q, terms):
+        x = promote(x)
+        y = promote(y)
+        update = x - y
+        if q.ndim < 2:
+            update = q * update
+        else:
+            update = (q @ update).triu()
+        out.append(update)
+    return out
+
+
 @decorator
 def psgd_update_precond(
     G: Tensor,
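`_chebychef_coeff` builds the coefficients of a Chebyshev series for f(x) = (x + 1 + eps)^(-1/2) via a discrete cosine transform over the Chebyshev nodes; evaluated with T_j(cos θ) = cos(jθ), the series interpolates f at those nodes. A small standalone check of that identity (my reading of the routine, not heavyball code):

import math
import torch

degree, eps = 8, 1e-8
k = torch.arange(degree, dtype=torch.float64)
theta = (2 * k + 1) * math.pi / (2 * degree)        # Chebyshev-node angles
f = (theta.cos() + 1 + eps) ** -0.5                 # target values at the nodes

coeff0 = f.sum() / degree
cos_table = (theta.view(-1, 1) * k[1:].view(1, -1)).cos()
coeffs = f @ cos_table * 2 / degree

# Evaluate the series at the same nodes: p(cos theta) = c0 + sum_j c_j * cos(j * theta)
p = coeff0 + cos_table @ coeffs
print(torch.allclose(p, f))  # True: the series interpolates f at the Chebyshev nodes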
@@ -2056,6 +2341,102 @@ def psgd_update_precond(
     return None
 
 
+@decorator_knowngood
+def bf16_matmul(x: Tensor, y: Tensor):
+    return (promote(x) @ promote(y)).to(x.dtype)
+
+
+def if_iscompiling(fn):
+    base = getattr(torch, fn.__name__, None)
+
+    def _fn(x):
+        if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
+            return base(x)
+        return fn(x)
+
+    return _fn
+
+
+@if_iscompiling
+def while_loop(cond, body, state):
+    """
+    dispatches to torch.while_loop if we're compiling. otherwise, falls back to a naive + slow baseline
+    useful for debugging
+    """
+    while cond(*state).item():
+        state = body(*state)
+    return state
+
+
+@if_iscompiling
+def cond(cond, true_fn, false_fn):
+    """
+    dispatches to torch.cond if we're compiling. otherwise, falls back to a naive + slow baseline
+    useful for debugging
+    """
+
+    if cond.item():
+        return true_fn()
+    return false_fn()
+
+
+@decorator_knowngood
+def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
+    """
+    Return w such that H = I - 2 w w^T is orthogonal and H e1 = v (v unit).
+    Applying from the right: G @ H = G - 2 (G @ w) w^T.
+    If v is (numerically) e1, returns w=0 and H=I.
+    """
+    v = v / v.norm().clamp(min=eps)
+    e1 = torch.zeros_like(v)
+    e1[0] = 1.0
+    w = e1 - v
+    return w / w.norm().clamp(min=eps)
+
+
+@decorator_knowngood
+def eigvecs_product_rank1(
+    G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compute Y = G @ V where V is an eigenvector matrix for P = λ I + σ v v^T,
+    using the Householder reflector with first column v. Never materializes V.
+
+    Args:
+        G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
+        v: shape (d,) — current unit direction (top eigenvector of P).
+        w: optional Householder vector w; pass to reuse across calls.
+
+    Returns:
+        (Y, w) where:
+            Y has shape (..., d) and equals G @ eigenvectors(P),
+            w is the Householder vector you can cache & reuse.
+    """
+    if w is None:
+        w = _householder_vec_e1_to_v(v, eps)
+    Y = G - 2.0 * compiled_einsum("...i,i,j->...j", G, w, w)
+    return Y, w
+
+
+@decorator_knowngood
+def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
+    """
+    One Oja step to track the top eigendirection of the gradient covariance.
+    v <- v + lr * ((g^T v) g - (g^T v)^2 v); then renormalize.
+    """
+    gv = g @ v
+    v = v + lr * (gv * g - (gv * gv) * v)
+    return v / v.norm().clamp(min=eps)
+
+
+def cond_n(cond_val: Tensor, *fns):
+    fns = list(fns)
+    fn = fns.pop(0)
+    if not fns:
+        return fn
+    return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
+
+
 @decorator_knowngood
 def _psgd_precond_update_(
     matmuled: List[Optional[Tensor]],
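Two of the new helpers can be checked directly: `_householder_vec_e1_to_v` produces w with H = I − 2wwᵀ orthogonal and H e₁ = v, and `oja_update` is one step of Oja's rule, which tracks the top eigendirection of the sample covariance of the gradients. A small standalone check of both properties:

import torch

def householder_vec(v, eps=1e-12):
    v = v / v.norm().clamp(min=eps)
    e1 = torch.zeros_like(v)
    e1[0] = 1.0
    w = e1 - v
    return w / w.norm().clamp(min=eps)

d = 8
v = torch.nn.functional.normalize(torch.randn(d), dim=0)
w = householder_vec(v)
H = torch.eye(d) - 2 * torch.outer(w, w)
print(torch.allclose(H @ H.T, torch.eye(d), atol=1e-6))   # H is orthogonal
print(torch.allclose(H[:, 0], v, atol=1e-6))              # first column of H is v, i.e. H e1 = v

# Oja's rule: converge to the top eigenvector of E[g g^T]
top = torch.nn.functional.normalize(torch.randn(d), dim=0)
u = torch.nn.functional.normalize(torch.randn(d), dim=0)
for _ in range(2000):
    g = torch.randn(1) * top + 0.1 * torch.randn(d)       # gradients correlated with `top`
    gu = g @ u
    u = u + 1e-2 * (gu * g - gu * gu * u)
    u = u / u.norm()
print(abs(u @ top))                                       # close to 1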
@@ -2074,7 +2455,7 @@ def _psgd_precond_update_(
     if update.ndim < 2:
         lb = update.norm(float("inf"))
     else:
-        lb = max_singular_value(update,
+        lb = max_singular_value(update, power_iter=power_iter)
     update = promote(update)
     if store_triu_as_line:
         update = triu_to_line([update])[0][1]
@@ -2146,70 +2527,83 @@ def inverse_free_psgd_update_precond(
 
 
 @decorator_knowngood
-def
-
-    x =
-    norm =
-
-
-
+def _clip(x, norm, clip_at, eps=1e-8):
+    x32 = promote(x)
+    # (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
+    norm = clip_at / norm.clamp(min=max(clip_at, eps))
+    x32 = x32 * norm
+    copy_stochastic_(x, x32)
+
+
+@decorator_knowngood
+def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
+    for x in xs:
+        _clip(x, promote(x).norm(), clip_at, eps)
 
 
 def l2_normalization_(x, clip_at: float = 1e-8):
     x = list_guard(x)
-
+    _compilable_l2_clip_(x, clip_at)
+    return x
 
 
 def l2_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    _compilable_l2_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
-def _compilable_rmsnorm_clip_(
-    x
-
-    norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm, clip_at)
-    return torch._foreach_div(x, norm)
+def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
+    for x in xs:
+        _clip(x, promote(x).square().mean().sqrt(), clip_at, eps)
 
 
 def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-
-
-
-@decorator_knowngood
-def _compilable_global_rmsnorm_clip_(x, clip_at):
-    x = list(map(promote, x))
-    norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
-    norm = norm**0.5
-    norm = norm.clamp(min=clip_at)
-    return torch._foreach_div(x, norm)
+    _compilable_rmsnorm_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
-def
-
-
-
-
-
+def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
+    norm = 0
+    numel = sum([i.numel() for i in x])
+    for i in x:
+        norm += promote(i).square().sum()
+    norm = (norm / numel) ** 0.5
+    scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+    stochastic_multiply_(x, scalar)
 
 
 def global_rmsnorm_clip(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    clip_at = scalar_guard(clip_at, x[0])
+    _compilable_global_rmsnorm_clip_(x, clip_at)
+    return x
+
+
+@decorator_knowngood
+def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
+    norm = 0
+    for i in x:
+        norm += promote(i).square().sum()
+    norm = norm**0.5
+    scalar = clip_at / norm.clamp(min=max(clip_at, eps))
+    stochastic_multiply_(x, scalar)
 
 
 def global_l2norm_clip(x, clip_at: float = 1.0):
     x = list_guard(x)
-
+    clip_at = scalar_guard(clip_at, x[0])
+    _compilable_global_l2norm_clip_(x, clip_at)
+    return x
 
 
 def rmsnorm_normalize_(x, clip_at: float = 1e-6):
     x = list_guard(x)
-
+    _compilable_rmsnorm_clip_(x, clip_at)
+    return x
 
 
 @decorator_knowngood
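The rewritten clipping helpers all share one formula: scale by clip_at / max(norm, clip_at, eps), so tensors whose (global) norm is already below clip_at are left untouched instead of being rescaled. A compact fp32 reference for the global RMS-norm variant (plain PyTorch, no stochastic writes):

import torch

def global_rmsnorm_clip_sketch(tensors, clip_at=1.0, eps=1e-8):
    numel = sum(t.numel() for t in tensors)
    norm = torch.sqrt(sum(t.float().square().sum() for t in tensors) / numel)
    scale = clip_at / norm.clamp(min=max(clip_at, eps))   # <= 1, and exactly 1 when norm <= clip_at
    for t in tensors:
        t.mul_(scale.to(t.dtype))
    return tensors

params = [torch.randn(10, 10) * 5, torch.randn(20)]
global_rmsnorm_clip_sketch(params, clip_at=1.0)
print(torch.sqrt(sum(t.square().sum() for t in params) / sum(t.numel() for t in params)))  # ~1.0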
@@ -2284,17 +2678,6 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
     _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
 
 
-@decorator_knowngood
-def _compilable_weight_decay_to_init_(p, init, weight_decay):
-    _lerp(p, promote(init), 1 - weight_decay)
-
-
-def weight_decay_to_init_(p, init, weight_decay):
-    p, init = list_guard(p, init)
-    weight_decay = scalar_guard(weight_decay, p[0])
-    _compilable_weight_decay_to_ema_(p, init, weight_decay)
-
-
 @decorator_knowngood
 def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
     ema32 = _lerp(ema, p, ema_decay)
@@ -2416,7 +2799,7 @@ def precond_grad_cached_(
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
-    expr = cached_precond_grad_expr(ndim_tuple(cached_q),
+    expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
     new = compiled_einsum(expr, *args)
     if cast:
         return new.to(ea.dtype)
@@ -2433,7 +2816,7 @@ def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, c
 
 
 def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
-    lr = scalar_guard(lr, param[0])
+    lr, decay = scalar_guard(lr, decay, param[0])
     _compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
 
 
@@ -2502,7 +2885,7 @@ def fused_psgd_precond_grad(
     store_triu_as_line: bool = False,
     symmetric_output: bool = False,
 ):
-    lr = scalar_guard(lr, param[0])
+    lr, decay = scalar_guard(lr, decay, param[0])
     _compilable_fused_psgd_precond_grad(
         ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
     )
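Several fused wrappers above gain `decay` in their `scalar_guard` call; the point of that guard pattern is to turn Python scalars into tensors co-located with the parameters before entering the compiled kernel, which among other things helps `torch.compile` avoid re-specializing on every new float value. A hypothetical standalone version of the pattern (heavyball's real `scalar_guard` may differ in signature and dtype handling):

import torch

def scalar_guard_sketch(*args):
    *scalars, ref = args                  # last argument is a reference tensor
    return tuple(torch.as_tensor(s, dtype=torch.float32, device=ref.device) for s in scalars)

param = torch.zeros(3)
lr, decay = scalar_guard_sketch(1e-3, 1e-2, param)
print(lr, decay)  # both 0-dim tensors living on the parameter's device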