heavyball 1.6.2__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +496 -100
- heavyball/chainable.py +444 -155
- heavyball/utils.py +326 -143
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/METADATA +11 -4
- heavyball-1.7.0.dist-info/RECORD +8 -0
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/WHEEL +1 -1
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.2.dist-info/RECORD +0 -8
- {heavyball-1.6.2.dist-info → heavyball-1.7.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -4,7 +4,7 @@ import math
 import random
 import string
 import warnings
-from typing import List, Optional, Tuple,
+from typing import Callable, List, Optional, Tuple, Union
 from unittest.mock import patch

 import numpy as np
@@ -17,25 +17,18 @@ from torch.utils._pytree import tree_map

 config.cache_size_limit = 2 ** 16

-np.warnings = warnings
-
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode =
+zeroth_power_mode = "qr"  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny

-base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
-             'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
-             'weight_decay': 1e-4}
-

 def decorator(func):
     compiled = None

     @functools.wraps(func)
     def _fn(*args, **kwargs):
-        disable = compile_mode_recommended_to_none is None
         if is_compiling() or compile_mode_recommended_to_none is None:
             return func(*args, **kwargs)
         nonlocal compiled
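The bare `zeroth_power_mode =` assignment now carries an explicit default and comment. A minimal usage sketch, assuming `heavyball.utils` is importable under that name and that the flag is read at call time by the orthogonalization helpers shown further down (`ortho`, `inplace_orthogonal_`):

    from heavyball import utils

    utils.zeroth_power_mode = "qr"              # baseline, per the new comment
    # utils.zeroth_power_mode = "newtonschulz"  # reportedly converges better and faster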
@@ -66,7 +59,7 @@ einsum_base = string.ascii_lowercase

 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
-
+                               beta1: Tensor, decay: float, grad: List[Tensor], caution, ):
     for op, oz, u_, g_ in zip(p, z, update, grad):
         u_ = u_.view_as(op)
         p_, z_, u_ = map(promote, (op, oz, u_))
@@ -82,8 +75,8 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u


 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
-
-
+                   z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0, step: int = 0,
+                   decay: float = 0.0, ):
     weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight

@@ -165,7 +158,7 @@ def eps_sqrt(item, eps):

 @decorator_knowngood
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
-
+                            out: List[Optional[Tensor]]):
     g32 = promote(grad)
     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)

@@ -227,7 +220,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val


 def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
-
+                                minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return gradients
     parameters, gradients = list_guard(parameters, gradients)
@@ -253,23 +246,22 @@ def clean():


 def _ignore_warning(msg):
-    warnings.filterwarnings(
+    warnings.filterwarnings("ignore", f".*{msg}.*")


-def set_torch(benchmark_limit: int = 32):
+def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
     cudnn.benchmark = True
     cudnn.deterministic = False
     cudnn.benchmark_limit = benchmark_limit
     torch.use_deterministic_algorithms(False)
     torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
-    opt_einsum.
-    opt_einsum.strategy = "auto"
+    opt_einsum.set_flags(True, einsum_strategy)

     # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
     _ignore_warning(
-
+        "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak")
     _ignore_warning(
-
+        "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak")


 @decorator
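`set_torch` now exposes the opt_einsum contraction strategy instead of hard-coding it. A minimal usage sketch, assuming it is called once at startup; "greedy" is a standard torch.backends.opt_einsum strategy and only an example value:

    from heavyball.utils import set_torch

    set_torch()                                             # benchmark_limit=32, einsum_strategy="auto"
    set_torch(benchmark_limit=8, einsum_strategy="greedy")  # pick a cheaper contraction search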
@@ -277,7 +269,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype)  # Preserve float64 if present
-    X /=
+    X /= X.norm() + eps  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
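The filled-in line normalizes by the Frobenius norm, which upper-bounds the spectral norm, so the input to the Newton-Schulz iteration indeed has top singular value at most 1. A standalone check of that rationale (not heavyball code):

    import torch

    G = torch.randn(64, 32, dtype=torch.float64)
    X = G / (G.norm() + 1e-7)                        # mirrors `X /= X.norm() + eps`
    assert torch.linalg.matrix_norm(X, ord=2) <= 1.0  # spectral norm <= Frobenius norm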
@@ -290,10 +282,10 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):


 def ortho(x):
-    if zeroth_power_mode ==
+    if zeroth_power_mode == "qr":
         return torch.linalg.qr(x).Q
-    if zeroth_power_mode ==
-        u,
+    if zeroth_power_mode == "svd":
+        u, _s, v = torch.linalg.svd(x)
         return u @ v.T
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")

@@ -351,12 +343,12 @@ def _compilable_grafting(magnitude, direction):

 @decorator_knowngood
 def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
-    if mode ==
+    if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode ==
+    elif mode == "qr":
         y = torch.linalg.qr(promote(x)).Q
-    elif mode ==
-        u,
+    elif mode == "svd":
+        u, _s, v = torch.linalg.svd(promote(x))
         y = u @ v.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -403,7 +395,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
         q_old = promote(q.data)

         tmp = m @ q_old
-        est_eig = torch.einsum(
+        est_eig = torch.einsum("ij,ij->j", q_old, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)

         tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
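The reconstructed einsum computes the per-column Rayleigh quotients diag(Qᵀ M Q), i.e. the eigenvalue estimates used for the descending sort. A small standalone equivalence check:

    import torch

    m = torch.randn(16, 16, dtype=torch.float64)
    m = m @ m.T                                    # symmetric PSD, like the GG statistics
    q_old = torch.linalg.qr(torch.randn(16, 16, dtype=torch.float64)).Q
    tmp = m @ q_old
    est_eig = torch.einsum("ij,ij->j", q_old, tmp)
    assert torch.allclose(est_eig, torch.diag(q_old.T @ m @ q_old))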
@@ -415,19 +407,19 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
         return

     assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
-    in_str = einsum_base[:exp_avg.dim()]
-    out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
+    in_str = einsum_base[: exp_avg.dim()]
+    out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]

     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
     if not from_shampoo:
         return

-    to_shampoo =
-    out_str =
+    to_shampoo = ",".join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
+    out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])

-    subscripts = f
+    subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
     exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
-
+                               *[q for q in new_qs if q is not None])
     copy_stochastic_(exp_avg, exp_avg_new)

     for q, q_new in zip(Q, new_qs):
@@ -453,11 +445,11 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
         while True:
             try:
                 eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
-
+                _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
                 eigvec = eigvec.to(device=device, dtype=dtype)
                 break
             except torch.OutOfMemoryError:
-                if m.device.type ==
+                if m.device.type == "cpu":
                     raise
                 else:
                     m = m.cpu()
@@ -489,21 +481,21 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa

 def get_beta1(group):
     beta = None
-    if
-        beta = group[
-    if beta is None and
-        beta = group[
+    if "beta" in group:
+        beta = group["beta"]
+    if beta is None and "betas" in group:
+        beta = group["betas"][0]
     if beta is None:
         raise ValueError("Beta not found in group.")
     return beta


 def get_beta2(group):
-    if
+    if "palm" in group and group["palm"] is True and "beta2_scale" in group:
         step = max(group.get("step", 1), 1)
-        return 1 - step ** -group[
-    if
-        return group[
+        return 1 - step ** -group["beta2_scale"]
+    if "betas" in group:
+        return group["betas"][1]
     raise ValueError("Beta2 not found in group.")


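The PaLM branch of `get_beta2` computes beta2 = 1 - step**(-beta2_scale), so the second-moment horizon grows with the step count. A worked example with beta2_scale = 0.8 (an illustrative value, not fixed by this diff):

    for step in (1, 10, 100, 1000):
        print(step, 1 - step ** -0.8)
    # 1 -> 0.0, 10 -> ~0.842, 100 -> ~0.975, 1000 -> ~0.996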
@@ -580,9 +572,9 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
         if not isinstance(m, Tensor):
             continue
         b = einsum_base[idx]
-        g0 = einsum_base[:grad.dim()]
+        g0 = einsum_base[: grad.dim()]
         g1 = g0.replace(b, b.upper())
-        outer_product = torch.einsum(f
+        outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
         stochastic_lerp_(m, outer_product, 1 - beta)


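For a 2D gradient the reconstructed einsum reduces to the usual Shampoo statistics: dimension 0 accumulates g @ g.T and dimension 1 accumulates g.T @ g. A standalone check:

    import torch

    g = torch.randn(8, 5, dtype=torch.float64)
    left = torch.einsum("ab,Ab->aA", g, g)     # idx 0: g0="ab", g1="Ab", output "aA"
    right = torch.einsum("ab,aB->bB", g, g)    # idx 1: g0="ab", g1="aB", output "bB"
    assert torch.allclose(left, g @ g.T) and torch.allclose(right, g.T @ g)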
@@ -623,19 +615,19 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
     """
     Initializes the preconditioner matrices (L and R in the paper).
     """
-    state[
+    state["GG"] = []  # Will hold all the preconditioner matrices (L and R in the paper).
     if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
         for sh in grad.shape:
             if sh > max_precond_dim or sh == 1:
                 # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
-                state[
+                state["GG"].append(None)
             else:
-                state[
+                state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
     else:
-        state[
+        state["GG"].append(None)

-    update_ggt(grad, state[
-    state[
+    update_ggt(grad, state["GG"], max_precond_dim, precondition_1d, 0)
+    state["Q"] = get_orthogonal_matrix(state["GG"])


 @decorator
@@ -646,11 +638,11 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param = einsum_base[:grad.dim()]
-    preconditioners = ",".join([(g + g.upper())[
+    param = einsum_base[: grad.dim()]
+    preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
     if preconditioners:
-        out =
-        out = torch.einsum(f
+        out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
+        out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
         grad = out.to(grad.dtype)
     return grad

@@ -667,12 +659,12 @@ def modify_closure(closure):
     """

     def patched_backward(self, *args, **kwargs):
-        kwargs[
+        kwargs["create_graph"] = True
         return original_backward(self, *args, **kwargs)

     original_backward = torch.Tensor.backward

-    with patch.object(torch.Tensor,
+    with patch.object(torch.Tensor, "backward", patched_backward):
         return closure()


@@ -683,6 +675,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
     The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
     Further notice that both methods have different numerics outputs
     """
+
     ema_decay: float = 0.001
     compile_step: bool = False
     hessian_approx: bool = False
@@ -691,10 +684,10 @@ class StatefulOptimizer(torch.optim.Optimizer):
     finite_differences: bool = False

     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
-        super().__init__(params, {**defaults,
+        super().__init__(params, {**defaults, "foreach": foreach})
         self.use_ema = use_ema
         self.mapping = {}
-        self._inner_group = {
+        self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
         self._precond_rng = random.Random(0x12312)
         self._is_preconditioning = None

@@ -710,24 +703,24 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
         for p, g in zip(p_list, g_list):
             state = self.state_(p)
-            if
-                state[
-        old_gs = [self.state_(p)[
+            if "mars_old_grad" not in state:
+                state["mars_old_grad"] = torch.zeros_like(g)
+        old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
         mars_correction(g_list, old_gs, mars_gamma, beta)

     def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
-
+                               beta1: float = -1.0):
         for p in group["params"]:
             if p in self.mapping:
                 p_views = self.mapping[p]
             else:
                 self.mapping[p] = p_views = merge_group(group, p)

-            grad = getattr(p,
+            grad = getattr(p, "grad", None)
             p.grad = None

             if grad is None:
-                grad = [getattr(pv,
+                grad = [getattr(pv, "grad", None) for pv in p_views]
             else:
                 grad = merge_group(group, grad)

@@ -736,8 +729,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 continue
             if should_promote:
                 g = promote(g)
-            if beta1 >= 0 and group.get(
-                self.mars_correct_list(group, [pv], [g], group[
+            if beta1 >= 0 and group.get("mars", False):
+                self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
             yield pv, g

     def state_size(self) -> int:
@@ -759,46 +752,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def ema_update(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]

                 if not active_p:
                     return

-                k = group[
+                k = group["ema_step"] = group.get("ema_step", -1) + 1

                 for p in active_p:
-                    if
-                        self.state_(p)[
+                    if "param_ema" not in self.state_(p):
+                        self.state_(p)["param_ema"] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

-                y, param_ema = zip(*[(p.data, self.state_(p)[
+                y, param_ema = zip(*[(p.data, self.state_(p)["param_ema"]) for p in active_p])
                 torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))

     def copy_emas_to_params(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]

                 if not active_p:
                     return

                 for p in active_p:
-                    if
+                    if "param_ema" in self.state_(p):
                         p_clone = p.data.clone()
-                        set_(p.data, self.state_(p)[
-                        set_(self.state_(p)[
+                        set_(p.data, self.state_(p)["param_ema"])
+                        set_(self.state_(p)["param_ema"], p_clone)

     def copy_params_to_emas(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]

                 if not active_p:
                     return

                 for p in active_p:
-                    if
-                        ema_clone = self.state_(p)[
-                        set_(self.state_(p)[
+                    if "param_ema" in self.state_(p):
+                        ema_clone = self.state_(p)["param_ema"].data.clone()
+                        set_(self.state_(p)["param_ema"], p.data)
                         set_(p.data, ema_clone)

     def _handle_closure(self, closure):
@@ -824,7 +817,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
                     grads.append(g)
                     p.vector = torch.randn_like(p)
                     p.orig = p.data.clone()
-
+                    # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
+                    stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
        else:
            with torch.enable_grad():
                loss = modify_closure(closure)
@@ -833,6 +827,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
            with torch.enable_grad():
                closure()

+            # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
+            # this costs more memory, but the imprecision seems too severe to use the other method
            for group in self.param_groups:
                for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
                    p.grad = grads.pop(0)
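The finite-differences path above perturbs every parameter by eps*v with eps = sqrt(machine epsilon) and re-runs the closure; the difference of the two gradients approximates the Hessian-vector product. A self-contained sketch of that identity (not heavyball's exact code path):

    import torch

    def grad_at(x):
        x = x.detach().requires_grad_(True)
        return torch.autograd.grad((x ** 3).sum(), x)[0]   # f(x) = sum(x^3), so Hv = 6*x*v

    x = torch.randn(10, dtype=torch.float64)
    v = torch.randn(10, dtype=torch.float64)
    eps = torch.finfo(x.dtype).eps ** 0.5
    fd_hvp = (grad_at(x + eps * v) - grad_at(x)) / eps
    assert torch.allclose(fd_hvp, 6 * x * v, atol=1e-5)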
@@ -845,7 +841,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
                for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
                    p.grad = g
            params, grads = zip(*[x for group in self.param_groups for x in
-
+                                  self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
            vs = [torch.randn_like(p) for p in params]
            with torch.enable_grad():
                hvs = torch.autograd.grad(grads, params, vs)
@@ -867,7 +863,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
        # we assume that parameters are constant and that there are no excessive recompiles
        with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
            for group in self.param_groups:
-                group[
+                group["is_preconditioning"] = self._is_preconditioning
                self._step(group)
                if self.use_ema:
                    self.ema_update()
@@ -892,7 +888,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):

 @decorator_knowngood
 def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
-
+                      step: Tensor, eps: Tensor, ):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -904,7 +900,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis


 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-
+          eps: float = 1e-8, ):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -913,8 +909,8 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b

 @decorator_knowngood
 def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
-
-
+                            grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
+                            caution: bool, ):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -926,8 +922,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:


 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
-
-
+                grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
+                caution: bool, ):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
@@ -935,7 +931,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]

 @decorator_knowngood
 def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
-
+                        beta2: Tensor, step: Tensor, eps: Tensor, ):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -947,7 +943,7 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L


 def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-
+            eps: float = 1e-8, ):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -956,8 +952,8 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],

 @decorator_knowngood
 def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
-
-
+                              grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor, caution: bool,
+                              eps: Tensor, ):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

@@ -969,8 +965,8 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq


 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
-
-
+                  grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
+                  eps: float = 1e-8, ):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -978,7 +974,7 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso

 @decorator_knowngood
 def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    u32, g32, exp_avg_sq32
+    u32, g32, exp_avg_sq32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq]]
     _compilable_update_(y, u32, decay, lr, caution, g32)

     beta1 = beta_debias(beta1, step)
@@ -997,7 +993,7 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e

 @decorator_knowngood
 def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
-    g32,
+    g32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
     update = [e.clone() for e in exp_avg]

     beta1 = beta_debias(beta1, step)
@@ -1045,7 +1041,7 @@ def copy_stochastic_(target: Tensor, source: Tensor):

 @decorator_knowngood
 def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
-
+                        g: List[Optional[Tensor]]):
     for u_, g_, p_ in zip(u, g, p):  # lr is data-dependent -> can't compile a foreach
         u_ = promote(u_.view_as(p_))
         p32_ = promote(p_)
@@ -1056,7 +1052,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten


 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
-
+                  grad: List[Tensor] = None):
     param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
@@ -1064,38 +1060,70 @@ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: f
     _compilable_update_(param, update, decay, lr, caution, grad)


-def precond_schedule(step, precond_scheduler
+def precond_schedule(step, precond_scheduler):
     precond_prob = max(step, 1) ** precond_scheduler[0]
     precond_prob = math.log10(precond_prob)
     precond_prob = precond_prob ** precond_scheduler[1] + 1
-
-    update_precond = rng.random() < precond_prob
-    return update_precond
+    return 1 / precond_prob


 def get_soap_precond_schedule(precond_scheduler):
-
-
-    def _inner(step):
-        return precond_schedule(step, precond_scheduler, rng)
-
-    return _inner
+    return functools.partial(precond_schedule, precond_scheduler=precond_scheduler)


 def _max_idx(x: List[int]):
     return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in


-
-
+@decorator_knowngood
+def mean_root(x: torch.Tensor, pow: float):
+    return stochastic_round_(x, x.float().pow(pow).mean().pow(-1 / pow / 2))
+
+
+@decorator_knowngood
+def divided_root(x, y, pow0, pow1):
+    mean_x = x.float().pow(pow0).mean().pow(1 / pow0 / 2)
+    mean_y = y.float().pow(pow1).mean().pow(-1 / pow1 / 2)
+    return stochastic_round_(x, mean_x * mean_y)  # multiply here, as we already divide in pow -1
+
+
+def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector):
+    if scale is not None:
+        warn_once(
+            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics.")
+        if scale_scale is not None:
+            warn_once(
+                "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale.")
+        return scale
+    if hessian_vector is None:
+        return mean_root(grad, 4) * scale_scale
+    return divided_root(vector, hessian_vector, 2, 4) * scale_scale
+
+
+def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
+    scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
+    U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
+    V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
+    d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
+    return U, V, d
+
+
+def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector,
+                 dtype=None):
+    """
+    For a scalar or tensor `grad`, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
+
+    precond init scale computation from
+    https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
     """
+    scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
     letters = string.ascii_lowercase + string.ascii_uppercase
-    dtype = dtype if dtype is not None else
-    shape =
+    dtype = dtype if dtype is not None else grad.dtype
+    shape = grad.shape

     if len(shape) == 0:  # scalar
-        Q = [scale * torch.ones_like(
+        Q = [scale * torch.ones_like(grad, dtype=dtype)]
         exprA = ",->"
         exprGs = [",->"]
         exprP = ",,->"
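With precond_init_scale=None, the new heuristic derives the initial scale from the gradient itself: mean_root(grad, 4) evaluates to (E[g^4])**(-1/8), and divided_root(v, hv, 2, 4) to (E[v^2])**(1/4) * (E[hv^4])**(-1/8) when a Hessian-vector product is available. The scalar math, without the stochastic rounding used above:

    import torch

    g = torch.randn(1024, dtype=torch.float64)
    init_scale = g.pow(4).mean().pow(-1 / 4 / 2)            # mean_root(g, 4)

    v = torch.randn(1024, dtype=torch.float64)              # probe vector
    hv = 3.0 * v                                            # stand-in Hessian-vector product
    init_scale_hvp = v.pow(2).mean().pow(1 / 2 / 2) * hv.pow(4).mean().pow(-1 / 4 / 2)  # divided_root(v, hv, 2, 4)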
@@ -1103,7 +1131,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp

     # Tensor
     if len(shape) > 13:
-        raise ValueError(f"Got tensor with dim {len(
+        raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")

     scale = scale ** (1 / len(shape))

@@ -1129,7 +1157,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
         if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
             # use diagonal matrix as preconditioner for this dim
-            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=
+            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))

             piece1A.append(letters[i])
             piece2A = piece2A + letters[i]
|
|
1143
1171
|
piece4P = piece4P + letters[i + 13]
|
1144
1172
|
else:
|
1145
1173
|
# use triangular matrix as preconditioner for this dim
|
1146
|
-
Q.append(scale * torch.eye(size, dtype=dtype, device=
|
1174
|
+
Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
|
1147
1175
|
piece1A.append(letters[i] + letters[i + 13])
|
1148
1176
|
piece2A = piece2A + letters[i + 13]
|
1149
1177
|
piece3A = piece3A + letters[i]
|
1150
1178
|
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1151
1179
|
piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
|
1152
|
-
subscripts =
|
1180
|
+
subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
|
1153
1181
|
exprGs.append(subscripts)
|
1154
1182
|
a, b, c = (letters[i], letters[i + 13], letters[i + 26])
|
1155
1183
|
piece1P.append(a + b)
|
@@ -1158,7 +1186,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
|
|
1158
1186
|
piece4P = piece4P + b
|
1159
1187
|
|
1160
1188
|
exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
|
1161
|
-
exprP =
|
1189
|
+
exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
|
1162
1190
|
return [Q, (exprA, tuple(exprGs), exprP)]
|
1163
1191
|
|
1164
1192
|
|
@@ -1170,17 +1198,171 @@ def psgd_balance_Q(Q_in):
     torch._foreach_mul_(Q_in, list(norms))


+@decorator
+def psgd_balance_lra(U: Tensor, V: Tensor):
+    u_norm = promote(torch.linalg.vector_norm(U))
+    v_norm = promote(torch.linalg.vector_norm(V))
+    scale = (u_norm / v_norm) ** 0.5
+    U.div_(scale)
+    V.mul_(scale)
+
+
+@decorator
+def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
+    dtype = min_dtype([U, V, x])
+    return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
+
+
+def update_lra_precond_(U: List[Tensor], V: List[Tensor], d: List[Tensor], vector: Tensor, hessian_vector: Tensor,
+                        eps: float, step: float, delayed: bool, ):
+    """
+    Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
+    """
+    U_orig, V_orig, d_orig = U, V, d
+
+    U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
+
+    dtype = min_dtype([U, V, vector, hessian_vector])
+    U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
+
+    eps = scalar_guard(eps, vector)
+
+    Qh = low_rank_mm(U, V, d * hessian_vector)
+    Ph = d * low_rank_mm(V, U, Qh)
+    rank = U.size(1)
+
+    VtU = torch.einsum("br,bn->rn", V, U)  # (rank, rank)
+    I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
+    IpVtU = I + VtU
+    invQtv = vector / d
+
+    # LU factorization to reuse computation
+    try:
+        LU, pivots = torch.linalg.lu_factor(IpVtU)
+    except RuntimeError:
+        # Error:
+        # U[2,2] is zero and using it on lu_solve would result in a division by zero.
+        # If you still want to perform the factorization, consider calling
+        # linalg.lu(A, pivot) or linalg.lu_factor_ex(A, pivot)
+        # ---
+        # So, we skip this step and reattempt on the next one
+        return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
+
+    invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
+    invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
+    invPv = invPv / d
+
+    nablaD = Ph * hessian_vector - vector * invPv
+    divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
+    divisor = divisor.add(eps).sqrt().max()
+    d_step = step / divisor
+
+    apply_flat_add(d_orig, d * nablaD, -d_step)
+
+    a, b = Qh, invQtv
+
+    precond_u = random.random() < 0.5  # update either U or V, not both at the same time
+    precond = V if precond_u else U
+    atV = torch.einsum("b,br->r", a, precond)  # o == one
+    btV = torch.einsum("b,br->r", b, precond)
+    atVVt = torch.einsum("r,br->b", atV, precond)
+    btVVt = torch.einsum("r,br->b", btV, precond)
+    precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
+    if precond_u:
+        a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
+        b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
+    else:
+        a = a + torch.einsum("br,r->b", V, atV)
+        b = b + torch.einsum("br,r->b", V, btV)
+        a = torch.einsum("b,r->br", a, atV)
+        b = torch.einsum("b,r->br", b, btV)
+    apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
+
+    if not delayed:
+        stochastic_add_([d], [d * nablaD], -d_step)
+        stochastic_add_([U if precond_u else V], [b - a], precond_step)
+    return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
+
+
+def lra_precond(U, V, d, g):
+    """
+    As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
+    """
+    g = low_rank_mm(U, V, d * g)
+    return d * low_rank_mm(V, U, g)
+
+
+@decorator_knowngood
+def dampen_grad(g: Tensor, damp: float = 2 ** -13):
+    # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
+    v = torch.randn_like(g)
+    return v, g + damp * g.abs().mean() * v
+
+
+@decorator_knowngood
+def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
+    update = lra_precond(U, V, d, update)
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        copy_stochastic_(p, update[start: start + size].view_as(p))
+        start += size
+
+
+@decorator_knowngood
+def apply_flat_update(params: List[Tensor], update: Tensor):
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        copy_stochastic_(p, update[start: start + size].view_as(p))
+        start += size
+
+
+@decorator_knowngood
+def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
+        start += size
+
+
+@decorator_knowngood
+def extract_from_flat_update(params: List[Tensor], update: Tensor):
+    start = 0
+    outputs = []
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        outputs.append(update[start: start + size].view_as(p))
+        start += size
+    return outputs
+
+
+def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
+    return torch.cat([i.flatten(0, -1 - remaining) for i in x], 0)
+
+
+def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
+    vs = []
+    gs = []
+    for g_ in g:
+        v, g = dampen_grad(g_, damp)
+        vs.append(v)
+        gs.append(g)
+    return flatten(vs), flatten(gs)
+
+
 def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
-    eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
-    eps *= G.norm() / G.numel()
-    G = G + torch.randn_like(G) * eps
-    md = min_dtype(Q + [G])
-    A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     if V is None:
-
-
-
+        V, G = dampen_grad(G)
+    conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
+    md = min_dtype(Q + [G])
+    A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
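The new low-rank (LRA) preconditioner applies P = D (I + V Uᵀ)(I + U Vᵀ) D to a flattened gradient, with D = diag(d). A standalone re-implementation of the two small helpers from this hunk (so the sketch does not depend on heavyball's decorators), checked against the dense matrix:

    import torch

    def low_rank_mm(U, V, x):                 # (I + U @ V.T) @ x, as in the hunk above
        return x + torch.einsum("br,gr,g->b", U, V, x)

    def lra_precond(U, V, d, g):
        g = low_rank_mm(U, V, d * g)
        return d * low_rank_mm(V, U, g)

    n, rank = 32, 4
    U = torch.randn(n, rank, dtype=torch.float64) * 0.1
    V = torch.randn(n, rank, dtype=torch.float64) * 0.1
    d = torch.rand(n, dtype=torch.float64) + 0.5
    g = torch.randn(n, dtype=torch.float64)
    I = torch.eye(n, dtype=torch.float64)
    P = torch.diag(d) @ (I + V @ U.T) @ (I + U @ V.T) @ torch.diag(d)
    assert torch.allclose(lra_precond(U, V, d, g), P @ g)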
@@ -1195,12 +1377,12 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V=None):

 def psgd_lb(A, max_abs):
     A /= max_abs
-    a0 = torch.einsum(
+    a0 = torch.einsum("ij,ij->j", A, A)
     i = torch.argmax(a0)
     x = torch.index_select(A, 1, i).flatten().contiguous()
-    x = torch.einsum(
+    x = torch.einsum("i,ij->j", x, A)
     x /= x.norm()
-    x = torch.einsum(
+    x = torch.einsum("j,kj->k", x, A)
     x = x.norm()
     x *= max_abs
     return x
@@ -1217,7 +1399,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
         term2 = promote(torch.einsum(exprG, conjB, conjB))
         term1, term2 = term1 - term2, term1 + term2
         term1 *= precond_lr
-        norm = term2.norm(float(
+        norm = term2.norm(float("inf"))
         if q.dim() < 2:
             term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
         else:
@@ -1245,7 +1427,7 @@ def l2_normalization_(x, clip_at: float = 1e-8):
     return _compilable_l2_clip_(x, clip_at)


-def l2_clip_(x, clip_at: float = 1.):
+def l2_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
     return _compilable_l2_clip_(x, clip_at)

@@ -1438,11 +1620,11 @@ def warn_once(msg):


 def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
-
-    group[f
+                       name: str = "cumulative_prob"):
+    group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
     if not isinstance(prob, float):
-        prob = prob(group[f
-    if group[
+        prob = prob(group[f"{name}_prob_step"])
+    if group["stochastic_schedule"]:
         return rng.random() < prob
     cumulative_prob = group.get(name, 0)
     group[name] = cumulative_prob + prob
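When stochastic_schedule is off, psgd_should_update accumulates the probability across calls; as we read it (the return statement falls outside this hunk), an update fires whenever the accumulated value crosses a new integer, giving evenly spaced deterministic updates at the requested rate. A sketch of that idea, not the library's verbatim code:

    def should_update(group, prob, name="cumulative_prob"):
        prev = group.get(name, 0.0)
        group[name] = prev + prob
        return int(group[name]) > int(prev)

    group = {}
    fires = sum(should_update(group, 0.25) for _ in range(100))
    print(fires)  # 25 of 100 calls trigger, evenly spaced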
@@ -1451,7 +1633,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random

 @decorator_knowngood
 def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
-
+                         cast: bool = True):
     if caution:
         ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
@@ -1564,15 +1746,16 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_


 def merge_group(group, *tensors):
-    if not group.get(
+    if not group.get("merge_dims", False):
         return tensors
     if isinstance(tensors[0], list):
         return [merge_group(group, *t) for t in tensors]

     out = []
     for t in tensors:
-        append_or_extend(out,
-
+        append_or_extend(out,
+                         dim_merger(t, group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
+                                    group.get("split", False), ), )
     return out


@@ -1599,7 +1782,7 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
     o = optimizer(parameters, *args, **kwargs)
     step_fn = o.step
     o.step = functools.partial(warn_once,
-
+                               msg="You're trying to call `step` on a fused optimizer. This will not do anything.")

     def _step(p: Tensor):
         seen_params.add(p)