heavyball 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +20 -1
- heavyball/chainable.py +50 -8
- heavyball/optimizations/__init__.py +38 -0
- heavyball/optimizations/integrator.py +169 -0
- heavyball/optimizations/optimizations.py +329 -0
- heavyball/utils.py +518 -162
- {heavyball-1.7.0.dist-info → heavyball-1.7.1.dist-info}/METADATA +1 -1
- heavyball-1.7.1.dist-info/RECORD +11 -0
- heavyball-1.7.0.dist-info/RECORD +0 -8
- {heavyball-1.7.0.dist-info → heavyball-1.7.1.dist-info}/WHEEL +0 -0
- {heavyball-1.7.0.dist-info → heavyball-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.0.dist-info → heavyball-1.7.1.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
|
+
import contextlib
|
1
2
|
import functools
|
2
3
|
import gc
|
4
|
+
import inspect
|
3
5
|
import math
|
4
6
|
import random
|
7
|
+
import re
|
5
8
|
import string
|
6
9
|
import warnings
|
7
10
|
from typing import Callable, List, Optional, Tuple, Union
|
8
|
-
from unittest.mock import patch
|
9
11
|
|
10
12
|
import numpy as np
|
11
13
|
import torch
|
@@ -15,13 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
|
|
15
17
|
from torch.backends import cudnn, opt_einsum
|
16
18
|
from torch.utils._pytree import tree_map
|
17
19
|
|
18
|
-
config.cache_size_limit = 2
|
20
|
+
config.cache_size_limit = 2**16
|
19
21
|
|
20
22
|
compile_mode = "max-autotune-no-cudagraphs"
|
21
23
|
dynamic = False
|
22
24
|
compile_mode_recommended_to_none = None
|
23
25
|
zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
|
24
26
|
tiny_bf16 = torch.finfo(torch.bfloat16).tiny
|
27
|
+
_cudnn_double_backward_pattern = re.compile(
|
28
|
+
r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
|
29
|
+
)
|
30
|
+
_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
|
31
|
+
_fd_error = (
|
32
|
+
"You can accelerate startup by globally enabling finite_differences first " #
|
33
|
+
"(via opt.finite_differences=True or by subclassing it)\n"
|
34
|
+
"Original Error: "
|
35
|
+
)
|
25
36
|
|
26
37
|
|
27
38
|
def decorator(func):
|
@@ -58,8 +69,17 @@ einsum_base = string.ascii_lowercase
|
|
58
69
|
|
59
70
|
|
60
71
|
@decorator_knowngood
|
61
|
-
def _compilable_schedule_free_(
|
62
|
-
|
72
|
+
def _compilable_schedule_free_(
|
73
|
+
p: List[Tensor],
|
74
|
+
z: List[Tensor],
|
75
|
+
ckp1: Tensor,
|
76
|
+
update: List[Tensor],
|
77
|
+
lr: Tensor,
|
78
|
+
beta1: Tensor,
|
79
|
+
decay: float,
|
80
|
+
grad: List[Tensor],
|
81
|
+
caution,
|
82
|
+
):
|
63
83
|
for op, oz, u_, g_ in zip(p, z, update, grad):
|
64
84
|
u_ = u_.view_as(op)
|
65
85
|
p_, z_, u_ = map(promote, (op, oz, u_))
|
@@ -74,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
|
|
74
94
|
copy_stochastic_(oz, z_)
|
75
95
|
|
76
96
|
|
77
|
-
def schedule_free_(
|
78
|
-
|
79
|
-
|
97
|
+
def schedule_free_(
|
98
|
+
lr: float,
|
99
|
+
weight_lr_power: float,
|
100
|
+
weight_sum: float,
|
101
|
+
beta1: float,
|
102
|
+
parameters: List[Tensor],
|
103
|
+
z: List[Tensor],
|
104
|
+
update: List[Tensor],
|
105
|
+
grad: List[Tensor],
|
106
|
+
caution: bool = False,
|
107
|
+
r: float = 0.0,
|
108
|
+
step: int = 0,
|
109
|
+
decay: float = 0.0,
|
110
|
+
):
|
80
111
|
weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
|
81
112
|
weight_sum = weight_sum + weight
|
82
113
|
|
@@ -149,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
149
180
|
|
150
181
|
|
151
182
|
def beta_debias(beta, step):
|
152
|
-
return 1 - (1 - beta) / (1 - beta
|
183
|
+
return 1 - (1 - beta) / (1 - beta**step)
|
153
184
|
|
154
185
|
|
155
186
|
def eps_sqrt(item, eps):
|
@@ -157,8 +188,9 @@ def eps_sqrt(item, eps):
|
|
157
188
|
|
158
189
|
|
159
190
|
@decorator_knowngood
|
160
|
-
def _compilable_exp_avg_sq_(
|
161
|
-
|
191
|
+
def _compilable_exp_avg_sq_(
|
192
|
+
state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
|
193
|
+
):
|
162
194
|
g32 = promote(grad)
|
163
195
|
s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
|
164
196
|
|
@@ -219,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
|
|
219
251
|
copy_stochastic_list_(gradients, g32)
|
220
252
|
|
221
253
|
|
222
|
-
def adaptive_gradient_clipping_(
|
223
|
-
|
254
|
+
def adaptive_gradient_clipping_(
|
255
|
+
parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
|
256
|
+
):
|
224
257
|
if clip_val <= 0:
|
225
258
|
return gradients
|
226
259
|
parameters, gradients = list_guard(parameters, gradients)
|
@@ -259,9 +292,11 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
|
259
292
|
|
260
293
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
261
294
|
_ignore_warning(
|
262
|
-
"Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
|
295
|
+
"Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
|
296
|
+
)
|
263
297
|
_ignore_warning(
|
264
|
-
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
298
|
+
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
299
|
+
)
|
265
300
|
|
266
301
|
|
267
302
|
@decorator
|
@@ -408,7 +443,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
408
443
|
|
409
444
|
assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
|
410
445
|
in_str = einsum_base[: exp_avg.dim()]
|
411
|
-
out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]
|
446
|
+
out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
|
412
447
|
|
413
448
|
from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
|
414
449
|
if not from_shampoo:
|
@@ -418,8 +453,9 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
418
453
|
out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
419
454
|
|
420
455
|
subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
|
421
|
-
exp_avg_new = torch.einsum(
|
422
|
-
*[q for q in new_qs if q is not None]
|
456
|
+
exp_avg_new = torch.einsum(
|
457
|
+
subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
|
458
|
+
)
|
423
459
|
copy_stochastic_(exp_avg, exp_avg_new)
|
424
460
|
|
425
461
|
for q, q_new in zip(Q, new_qs):
|
@@ -546,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
|
|
546
582
|
_compilable_stochastic_add_(x, y, alpha)
|
547
583
|
|
548
584
|
|
585
|
+
@decorator_knowngood
|
586
|
+
def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
|
587
|
+
for x_, y_ in zip(x, y):
|
588
|
+
x32 = promote(x_)
|
589
|
+
y32 = promote(y_)
|
590
|
+
copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
|
591
|
+
|
592
|
+
|
593
|
+
def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
|
594
|
+
x, y = list_guard(x, y)
|
595
|
+
alpha, divisor = scalar_guard(alpha, divisor, x[0])
|
596
|
+
_compilable_stochastic_add_divide_(x, y, alpha, divisor)
|
597
|
+
|
598
|
+
|
549
599
|
@decorator_knowngood
|
550
600
|
def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
551
601
|
for x_, y_ in zip(x, y):
|
@@ -594,6 +644,20 @@ def promote(x):
|
|
594
644
|
return x
|
595
645
|
|
596
646
|
|
647
|
+
def promote_detach(x, should_promote):
|
648
|
+
if x is None:
|
649
|
+
return x
|
650
|
+
if should_promote:
|
651
|
+
x = promote(x)
|
652
|
+
return x.detach()
|
653
|
+
|
654
|
+
|
655
|
+
def detach(x):
|
656
|
+
if isinstance(x, Tensor):
|
657
|
+
return x.detach()
|
658
|
+
return x
|
659
|
+
|
660
|
+
|
597
661
|
def min_dtype(xs: List[Tensor]):
|
598
662
|
dtypes = [x.dtype for x in xs]
|
599
663
|
for d in (torch.float32, torch.bfloat16, torch.float16):
|
@@ -647,25 +711,36 @@ def project(grad, Q, back: bool):
|
|
647
711
|
return grad
|
648
712
|
|
649
713
|
|
650
|
-
|
651
|
-
|
652
|
-
|
714
|
+
@contextlib.contextmanager
|
715
|
+
def patch_backward():
|
716
|
+
@contextlib.contextmanager
|
717
|
+
def _inner(module):
|
718
|
+
original = module.backward
|
653
719
|
|
654
|
-
|
655
|
-
closure: The closure function passed to the optimizer.
|
720
|
+
signature = inspect.signature(original)
|
656
721
|
|
657
|
-
|
658
|
-
|
659
|
-
|
722
|
+
def patched_backward(*args, **kwargs):
|
723
|
+
new_kwargs = signature.bind(*args)
|
724
|
+
new_kwargs.apply_defaults()
|
725
|
+
new_kwargs = new_kwargs.arguments
|
726
|
+
new_kwargs.update(kwargs)
|
727
|
+
new_kwargs["create_graph"] = True
|
728
|
+
return original(**new_kwargs)
|
660
729
|
|
661
|
-
|
662
|
-
|
663
|
-
|
730
|
+
module.backward = patched_backward
|
731
|
+
yield
|
732
|
+
module.backward = original
|
664
733
|
|
665
|
-
|
734
|
+
with _inner(torch.Tensor), _inner(torch.autograd):
|
735
|
+
yield
|
666
736
|
|
667
|
-
|
668
|
-
|
737
|
+
|
738
|
+
def hasattr_none(obj, name):
|
739
|
+
return getattr(obj, name, None) is not None
|
740
|
+
|
741
|
+
|
742
|
+
class ExactHVPFailed(ValueError):
|
743
|
+
pass
|
669
744
|
|
670
745
|
|
671
746
|
class StatefulOptimizer(torch.optim.Optimizer):
|
@@ -682,6 +757,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
682
757
|
precond_schedule: Union[Callable, float, None] = None
|
683
758
|
stochastic_schedule: bool = False
|
684
759
|
finite_differences: bool = False
|
760
|
+
fallback_to_finite_differences: bool = True
|
761
|
+
_fallback_enabled: bool = False
|
762
|
+
hvp_interval: int = 1 # grad is faster initially, hvp later
|
685
763
|
|
686
764
|
def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
|
687
765
|
super().__init__(params, {**defaults, "foreach": foreach})
|
@@ -708,29 +786,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
708
786
|
old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
|
709
787
|
mars_correction(g_list, old_gs, mars_gamma, beta)
|
710
788
|
|
711
|
-
def split_p_and_g_in_group(
|
712
|
-
|
789
|
+
def split_p_and_g_in_group(
|
790
|
+
self,
|
791
|
+
group: dict,
|
792
|
+
skip_none: bool = True,
|
793
|
+
should_promote: bool = True,
|
794
|
+
beta1: float = -1.0,
|
795
|
+
raw: bool = False,
|
796
|
+
):
|
713
797
|
for p in group["params"]:
|
798
|
+
grad = getattr(p, "grad", None)
|
799
|
+
if grad is None and skip_none:
|
800
|
+
continue
|
801
|
+
|
802
|
+
p.grad = None
|
803
|
+
|
804
|
+
if raw:
|
805
|
+
yield p, grad
|
806
|
+
continue
|
807
|
+
|
714
808
|
if p in self.mapping:
|
715
809
|
p_views = self.mapping[p]
|
716
810
|
else:
|
717
811
|
self.mapping[p] = p_views = merge_group(group, p)
|
718
812
|
|
719
|
-
|
720
|
-
|
813
|
+
vector = getattr(p, "vector", None)
|
814
|
+
hessian_vector = getattr(p, "hessian_vector", None)
|
815
|
+
p.vector = None
|
816
|
+
p.hessian_vector = None
|
721
817
|
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
818
|
+
grad, vs, hvs = [
|
819
|
+
[None] * len(p_views) if x is None else merge_group(group, x) #
|
820
|
+
for x in (grad, vector, hessian_vector)
|
821
|
+
]
|
726
822
|
|
727
|
-
for pv, g in zip(p_views, grad):
|
728
|
-
|
729
|
-
continue
|
730
|
-
if should_promote:
|
731
|
-
g = promote(g)
|
823
|
+
for pv, g, v, hv in zip(p_views, grad, vs, hvs):
|
824
|
+
g = promote_detach(g, should_promote)
|
732
825
|
if beta1 >= 0 and group.get("mars", False):
|
733
826
|
self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
|
827
|
+
pv.vector = promote_detach(v, should_promote)
|
828
|
+
pv.hessian_vector = promote_detach(hv, should_promote)
|
734
829
|
yield pv, g
|
735
830
|
|
736
831
|
def state_size(self) -> int:
|
@@ -794,6 +889,66 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
794
889
|
set_(self.state_(p)["param_ema"], p.data)
|
795
890
|
set_(p.data, ema_clone)
|
796
891
|
|
892
|
+
def _finite_differences_hvp(self, closure):
|
893
|
+
with torch.enable_grad():
|
894
|
+
loss = closure() # closure without retain_graph=True
|
895
|
+
|
896
|
+
grads = []
|
897
|
+
for group in self.param_groups:
|
898
|
+
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
|
899
|
+
grads.append(g)
|
900
|
+
p.vector = torch.randn_like(p)
|
901
|
+
p.orig = p.data.clone()
|
902
|
+
# scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
|
903
|
+
stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
|
904
|
+
|
905
|
+
with torch.enable_grad():
|
906
|
+
closure()
|
907
|
+
|
908
|
+
# we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
|
909
|
+
# this costs more memory, but the imprecision seems too severe to use the other method
|
910
|
+
for group in self.param_groups:
|
911
|
+
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
|
912
|
+
p.grad = grads.pop(0)
|
913
|
+
stochastic_add_(g, p.grad, -1) # technically, we have to divide by the scale here
|
914
|
+
p.hessian_vector = g
|
915
|
+
p.data.copy_(p.orig)
|
916
|
+
del p.orig
|
917
|
+
return loss
|
918
|
+
|
919
|
+
def _double_backward_hvp(self, closure):
|
920
|
+
with torch.enable_grad(), patch_backward():
|
921
|
+
loss = closure()
|
922
|
+
|
923
|
+
params, grads = [], []
|
924
|
+
for group in self.param_groups:
|
925
|
+
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
|
926
|
+
params.append(p)
|
927
|
+
grads.append(g)
|
928
|
+
|
929
|
+
if not params:
|
930
|
+
raise ValueError("No parameter has gradients")
|
931
|
+
|
932
|
+
vs = [torch.randn_like(p) for p in params]
|
933
|
+
with torch.enable_grad():
|
934
|
+
try:
|
935
|
+
hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
|
936
|
+
except RuntimeError as e:
|
937
|
+
raise ExactHVPFailed(str(e.args))
|
938
|
+
|
939
|
+
unused = []
|
940
|
+
for p, g, v, hv in zip(params, grads, vs, hvs):
|
941
|
+
p.hessian_vector = detach(hv)
|
942
|
+
p.grad = detach(g)
|
943
|
+
p.vector = detach(v)
|
944
|
+
if hv is None:
|
945
|
+
unused.append(list(p.shape))
|
946
|
+
|
947
|
+
if unused:
|
948
|
+
raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")
|
949
|
+
|
950
|
+
return loss
|
951
|
+
|
797
952
|
def _handle_closure(self, closure):
|
798
953
|
hessian_approx = self.hessian_approx and self._is_preconditioning
|
799
954
|
|
@@ -802,56 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
802
957
|
raise ValueError("Hessian approximation requires a closure.")
|
803
958
|
return None
|
804
959
|
|
805
|
-
|
960
|
+
step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
|
961
|
+
if not hessian_approx or step % self.hvp_interval == 0:
|
806
962
|
with torch.enable_grad():
|
807
963
|
loss = closure()
|
808
964
|
return loss
|
809
965
|
|
810
|
-
if self.finite_differences:
|
811
|
-
|
812
|
-
loss = closure() # closure without retain_graph=True
|
813
|
-
|
814
|
-
grads = []
|
815
|
-
for group in self.param_groups:
|
816
|
-
for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
|
817
|
-
grads.append(g)
|
818
|
-
p.vector = torch.randn_like(p)
|
819
|
-
p.orig = p.data.clone()
|
820
|
-
# scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
|
821
|
-
stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
|
822
|
-
else:
|
823
|
-
with torch.enable_grad():
|
824
|
-
loss = modify_closure(closure)
|
825
|
-
|
826
|
-
if self.finite_differences:
|
827
|
-
with torch.enable_grad():
|
828
|
-
closure()
|
829
|
-
|
830
|
-
# we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
|
831
|
-
# this costs more memory, but the imprecision seems too severe to use the other method
|
832
|
-
for group in self.param_groups:
|
833
|
-
for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
|
834
|
-
p.grad = grads.pop(0)
|
835
|
-
stochastic_add_(g, p.grad, -1)
|
836
|
-
p.hessian_vector = g
|
837
|
-
p.data.copy_(p.orig)
|
838
|
-
del p.orig
|
839
|
-
else:
|
840
|
-
for group in self.param_groups:
|
841
|
-
for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
|
842
|
-
p.grad = g
|
843
|
-
params, grads = zip(*[x for group in self.param_groups for x in
|
844
|
-
self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
|
845
|
-
vs = [torch.randn_like(p) for p in params]
|
846
|
-
with torch.enable_grad():
|
847
|
-
hvs = torch.autograd.grad(grads, params, vs)
|
848
|
-
|
849
|
-
for p, g, v, hv in zip(params, grads, vs, hvs):
|
850
|
-
p.hessian_vector = hv
|
851
|
-
p.grad = g
|
852
|
-
p.vector = v
|
966
|
+
if self.finite_differences or self._fallback_enabled:
|
967
|
+
return self._finite_differences_hvp(closure)
|
853
968
|
|
854
|
-
|
969
|
+
try:
|
970
|
+
return self._double_backward_hvp(closure)
|
971
|
+
except NotImplementedError as e:
|
972
|
+
if not self.fallback_to_finite_differences:
|
973
|
+
raise
|
974
|
+
if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
|
975
|
+
raise
|
976
|
+
warn_once(
|
977
|
+
"CUDNN doesn't support double-backward for some models (including RNNs). " #
|
978
|
+
f"Falling back to finite_differences.\n{_fd_error}{e}"
|
979
|
+
)
|
980
|
+
except RuntimeError as e:
|
981
|
+
if not self.fallback_to_finite_differences:
|
982
|
+
raise
|
983
|
+
if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
|
984
|
+
raise
|
985
|
+
warn_once(
|
986
|
+
f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
|
987
|
+
f"the model.\n{_fd_error}{e}"
|
988
|
+
)
|
989
|
+
except ExactHVPFailed as e:
|
990
|
+
if not self.fallback_to_finite_differences:
|
991
|
+
raise
|
992
|
+
warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
|
993
|
+
self._fallback_enabled = True
|
994
|
+
return self._handle_closure(closure)
|
855
995
|
|
856
996
|
def step(self, closure: Optional[Callable] = None):
|
857
997
|
if self.precond_schedule is None:
|
@@ -867,7 +1007,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
867
1007
|
self._step(group)
|
868
1008
|
if self.use_ema:
|
869
1009
|
self.ema_update()
|
870
|
-
|
1010
|
+
for real, views in self.mapping.items():
|
1011
|
+
for tensor in (real, *views):
|
1012
|
+
for key in ("grad", "vector", "hessian_vector", "orig"):
|
1013
|
+
if hasattr(tensor, key):
|
1014
|
+
setattr(tensor, key, None)
|
871
1015
|
return loss
|
872
1016
|
|
873
1017
|
|
@@ -887,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
|
|
887
1031
|
|
888
1032
|
|
889
1033
|
@decorator_knowngood
|
890
|
-
def _compilable_adam_(
|
891
|
-
|
1034
|
+
def _compilable_adam_(
|
1035
|
+
exp_avg: List[Tensor],
|
1036
|
+
exp_avg_sq: List[Tensor],
|
1037
|
+
grad: List[Tensor],
|
1038
|
+
beta1: Tensor,
|
1039
|
+
beta2: Tensor,
|
1040
|
+
step: Tensor,
|
1041
|
+
eps: Tensor,
|
1042
|
+
):
|
892
1043
|
beta1 = beta_debias(beta1, step)
|
893
1044
|
beta2 = beta_debias(beta2, step)
|
894
1045
|
|
@@ -899,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
|
|
899
1050
|
copy_stochastic_list_(grad, u32)
|
900
1051
|
|
901
1052
|
|
902
|
-
def adam_(
|
903
|
-
|
1053
|
+
def adam_(
|
1054
|
+
exp_avg: List[Tensor],
|
1055
|
+
exp_avg_sq: List[Tensor],
|
1056
|
+
grad: List[Tensor],
|
1057
|
+
beta1: float,
|
1058
|
+
beta2: float,
|
1059
|
+
step: int,
|
1060
|
+
eps: float = 1e-8,
|
1061
|
+
):
|
904
1062
|
exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
|
905
1063
|
beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
|
906
1064
|
_compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
|
@@ -908,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
|
|
908
1066
|
|
909
1067
|
|
910
1068
|
@decorator_knowngood
|
911
|
-
def _fused_compilable_adam_(
|
912
|
-
|
913
|
-
|
1069
|
+
def _fused_compilable_adam_(
|
1070
|
+
y: List[Tensor],
|
1071
|
+
exp_avg: List[Tensor],
|
1072
|
+
exp_avg_sq: List[Tensor],
|
1073
|
+
update: List[Tensor],
|
1074
|
+
grad: List[Tensor],
|
1075
|
+
beta1: Tensor,
|
1076
|
+
beta2: Tensor,
|
1077
|
+
step: Tensor,
|
1078
|
+
decay: Tensor,
|
1079
|
+
lr: Tensor,
|
1080
|
+
eps: Tensor,
|
1081
|
+
caution: bool,
|
1082
|
+
):
|
914
1083
|
beta1 = beta_debias(beta1, step)
|
915
1084
|
beta2 = beta_debias(beta2, step)
|
916
1085
|
|
@@ -921,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
|
|
921
1090
|
_compilable_update_(y, u32, decay, lr, caution, g32)
|
922
1091
|
|
923
1092
|
|
924
|
-
def fused_adam_(
|
925
|
-
|
926
|
-
|
1093
|
+
def fused_adam_(
|
1094
|
+
y: List[Tensor],
|
1095
|
+
exp_avg: List[Tensor],
|
1096
|
+
exp_avg_sq: List[Tensor],
|
1097
|
+
update: List[Tensor],
|
1098
|
+
grad: List[Tensor],
|
1099
|
+
beta1: float,
|
1100
|
+
beta2: float,
|
1101
|
+
step: int,
|
1102
|
+
lr: float,
|
1103
|
+
eps: float,
|
1104
|
+
decay: float,
|
1105
|
+
caution: bool,
|
1106
|
+
):
|
927
1107
|
y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
|
928
1108
|
beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
|
929
1109
|
_fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
|
930
1110
|
|
931
1111
|
|
932
1112
|
@decorator_knowngood
|
933
|
-
def _compilable_laprop_(
|
934
|
-
|
1113
|
+
def _compilable_laprop_(
|
1114
|
+
exp_avg: List[Tensor],
|
1115
|
+
exp_avg_sq: List[Tensor],
|
1116
|
+
grad: List[Tensor],
|
1117
|
+
beta1: Tensor,
|
1118
|
+
beta2: Tensor,
|
1119
|
+
step: Tensor,
|
1120
|
+
eps: Tensor,
|
1121
|
+
):
|
935
1122
|
beta1 = beta_debias(beta1, step)
|
936
1123
|
beta2 = beta_debias(beta2, step)
|
937
1124
|
|
@@ -942,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
|
|
942
1129
|
copy_stochastic_list_(grad, gp32)
|
943
1130
|
|
944
1131
|
|
945
|
-
def laprop_(
|
946
|
-
|
1132
|
+
def laprop_(
|
1133
|
+
exp_avg: List[Tensor],
|
1134
|
+
exp_avg_sq: List[Tensor],
|
1135
|
+
grad: List[Tensor],
|
1136
|
+
beta1: float,
|
1137
|
+
beta2: float,
|
1138
|
+
step: int,
|
1139
|
+
eps: float = 1e-8,
|
1140
|
+
):
|
947
1141
|
exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
|
948
1142
|
beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
|
949
1143
|
_compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
|
@@ -951,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
|
|
951
1145
|
|
952
1146
|
|
953
1147
|
@decorator_knowngood
|
954
|
-
def _fused_compilable_laprop_(
|
955
|
-
|
956
|
-
|
1148
|
+
def _fused_compilable_laprop_(
|
1149
|
+
y: List[Tensor],
|
1150
|
+
exp_avg: List[Tensor],
|
1151
|
+
exp_avg_sq: List[Tensor],
|
1152
|
+
update: List[Tensor],
|
1153
|
+
grad: List[Tensor],
|
1154
|
+
beta1: Tensor,
|
1155
|
+
beta2: Tensor,
|
1156
|
+
step: Tensor,
|
1157
|
+
lr: Tensor,
|
1158
|
+
decay: Tensor,
|
1159
|
+
caution: bool,
|
1160
|
+
eps: Tensor,
|
1161
|
+
):
|
957
1162
|
beta1 = beta_debias(beta1, step)
|
958
1163
|
beta2 = beta_debias(beta2, step)
|
959
1164
|
|
@@ -964,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
|
|
964
1169
|
_compilable_update_(y, u32, decay, lr, caution, gp32)
|
965
1170
|
|
966
1171
|
|
967
|
-
def fused_laprop_(
|
968
|
-
|
969
|
-
|
1172
|
+
def fused_laprop_(
|
1173
|
+
y: List[Tensor],
|
1174
|
+
exp_avg: List[Tensor],
|
1175
|
+
exp_avg_sq: List[Tensor],
|
1176
|
+
update: List[Tensor],
|
1177
|
+
grad: List[Tensor],
|
1178
|
+
beta1: float,
|
1179
|
+
beta2: float,
|
1180
|
+
step: int,
|
1181
|
+
lr: float,
|
1182
|
+
decay: float,
|
1183
|
+
caution: bool,
|
1184
|
+
eps: float = 1e-8,
|
1185
|
+
):
|
970
1186
|
exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
|
971
1187
|
beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
|
972
1188
|
_fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
|
@@ -1040,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
|
|
1040
1256
|
|
1041
1257
|
|
1042
1258
|
@decorator_knowngood
|
1043
|
-
def _compilable_update_(
|
1044
|
-
|
1259
|
+
def _compilable_update_(
|
1260
|
+
p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
|
1261
|
+
):
|
1045
1262
|
for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
|
1046
1263
|
u_ = promote(u_.view_as(p_))
|
1047
1264
|
p32_ = promote(p_)
|
@@ -1051,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
|
|
1051
1268
|
copy_stochastic_(p_, p32_)
|
1052
1269
|
|
1053
1270
|
|
1054
|
-
def update_param_(
|
1055
|
-
|
1271
|
+
def update_param_(
|
1272
|
+
param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
|
1273
|
+
):
|
1056
1274
|
param, update, grad = list_guard(param, update, grad)
|
1057
1275
|
lr = scalar_guard(lr, param[0])
|
1058
1276
|
if not caution:
|
@@ -1076,28 +1294,74 @@ def _max_idx(x: List[int]):
|
|
1076
1294
|
|
1077
1295
|
|
1078
1296
|
@decorator_knowngood
|
1079
|
-
def
|
1080
|
-
|
1297
|
+
def stable_exp(x: Tensor):
|
1298
|
+
# fp16:
|
1299
|
+
# exp(x) is stable in [-17, 11]
|
1300
|
+
# `stable_exp` extends to [-17, 17]
|
1301
|
+
# average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
|
1302
|
+
# fp32:
|
1303
|
+
# exp(x) is stable in [-103, 88]
|
1304
|
+
# `stable_exp` extends to [-103, 103]
|
1305
|
+
# average error (in [-87, 87]) reduced from 3.309-06 to 3.224-06
|
1306
|
+
return torch.where(x > 0, 1 / (-x).exp(), x.exp())
|
1081
1307
|
|
1082
1308
|
|
1083
1309
|
@decorator_knowngood
|
1084
|
-
def
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1310
|
+
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
|
1311
|
+
# 1 / (mean(x ** pow) ** (1 / pow / 2))
|
1312
|
+
log_x = x.double().abs().clamp(min=eps).log()
|
1313
|
+
log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
|
1314
|
+
return stable_exp(-log_mean_x_pow / pow / 2)
|
1315
|
+
|
1316
|
+
|
1317
|
+
@decorator_knowngood
|
1318
|
+
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
|
1319
|
+
# mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
|
1320
|
+
log_x = x.double().abs().clamp(min=eps).log()
|
1321
|
+
log_y = y.double().abs().clamp(min=eps).log()
|
1322
|
+
|
1323
|
+
x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
|
1324
|
+
x_normed = x_normed / pow0 / 2
|
1325
|
+
|
1326
|
+
y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
|
1327
|
+
y_normed = y_normed / pow1 / 2
|
1088
1328
|
|
1329
|
+
return stable_exp(x_normed - y_normed)
|
1089
1330
|
|
1090
|
-
|
1331
|
+
|
1332
|
+
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
|
1333
|
+
automatic_scale = True
|
1334
|
+
manual_hint = " Set it manually using `precond_init_scale=0.1`"
|
1091
1335
|
if scale is not None:
|
1336
|
+
automatic_scale = False
|
1092
1337
|
warn_once(
|
1093
|
-
"It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
|
1094
|
-
|
1338
|
+
"It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
|
1339
|
+
)
|
1340
|
+
if scale_scale is not None and scale_scale != 1:
|
1095
1341
|
warn_once(
|
1096
|
-
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
|
1342
|
+
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
|
1343
|
+
)
|
1344
|
+
elif hessian_vector is None:
|
1345
|
+
scale = mean_root(grad, 4) * scale_scale
|
1346
|
+
else:
|
1347
|
+
scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
|
1348
|
+
if isinstance(scale, torch.Tensor):
|
1349
|
+
scale = scale.item() # slow, but necessary
|
1350
|
+
if np.isfinite(scale):
|
1351
|
+
if scale > scale_max or scale < 1 / scale_max:
|
1352
|
+
warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
|
1097
1353
|
return scale
|
1098
|
-
if
|
1099
|
-
|
1100
|
-
|
1354
|
+
if not automatic_scale:
|
1355
|
+
raise ValueError("The manually set precond_init_scale is not finite")
|
1356
|
+
|
1357
|
+
for x in (grad, hessian_vector, vector):
|
1358
|
+
if x is None:
|
1359
|
+
continue
|
1360
|
+
if torch.allclose(x, torch.zeros_like(x)).item():
|
1361
|
+
raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
|
1362
|
+
if not torch.isfinite(x).all().item():
|
1363
|
+
raise ValueError("Grad or HVP is not finite")
|
1364
|
+
raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
|
1101
1365
|
|
1102
1366
|
|
1103
1367
|
def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
|
@@ -1108,8 +1372,9 @@ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None)
|
|
1108
1372
|
return U, V, d
|
1109
1373
|
|
1110
1374
|
|
1111
|
-
def init_Q_exprs(
|
1112
|
-
|
1375
|
+
def init_Q_exprs(
|
1376
|
+
grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
|
1377
|
+
):
|
1113
1378
|
"""
|
1114
1379
|
For a scalar or tensor `grad`, we initialize its preconditioner Q and
|
1115
1380
|
reusable einsum expressions for updating Q and preconditioning gradient.
|
@@ -1147,8 +1412,10 @@ def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory
|
|
1147
1412
|
elif memory_save_mode == "all_diag":
|
1148
1413
|
dim_diag = [True for _ in shape]
|
1149
1414
|
else:
|
1150
|
-
raise ValueError(
|
1151
|
-
|
1415
|
+
raise ValueError(
|
1416
|
+
f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
|
1417
|
+
"[None, 'one_diag', 'all_diag', 'smart_one_diag']"
|
1418
|
+
)
|
1152
1419
|
|
1153
1420
|
Q = []
|
1154
1421
|
piece1A, piece2A, piece3A = ([], "", "")
|
@@ -1213,8 +1480,16 @@ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
|
|
1213
1480
|
return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
|
1214
1481
|
|
1215
1482
|
|
1216
|
-
def update_lra_precond_(
|
1217
|
-
|
1483
|
+
def update_lra_precond_(
|
1484
|
+
U: List[Tensor],
|
1485
|
+
V: List[Tensor],
|
1486
|
+
d: List[Tensor],
|
1487
|
+
vector: Tensor,
|
1488
|
+
hessian_vector: Tensor,
|
1489
|
+
eps: float,
|
1490
|
+
step: float,
|
1491
|
+
delayed: bool,
|
1492
|
+
):
|
1218
1493
|
"""
|
1219
1494
|
Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
|
1220
1495
|
"""
|
@@ -1293,7 +1568,7 @@ def lra_precond(U, V, d, g):
|
|
1293
1568
|
|
1294
1569
|
|
1295
1570
|
@decorator_knowngood
|
1296
|
-
def dampen_grad(g: Tensor, damp: float = 2
|
1571
|
+
def dampen_grad(g: Tensor, damp: float = 2**-13):
|
1297
1572
|
# https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
|
1298
1573
|
v = torch.randn_like(g)
|
1299
1574
|
return v, g + damp * g.abs().mean() * v
|
@@ -1306,7 +1581,7 @@ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor,
|
|
1306
1581
|
update = update.flatten()
|
1307
1582
|
for p in params:
|
1308
1583
|
size = p.numel()
|
1309
|
-
copy_stochastic_(p, update[start: start + size].view_as(p))
|
1584
|
+
copy_stochastic_(p, update[start : start + size].view_as(p))
|
1310
1585
|
start += size
|
1311
1586
|
|
1312
1587
|
|
@@ -1316,7 +1591,7 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
|
|
1316
1591
|
update = update.flatten()
|
1317
1592
|
for p in params:
|
1318
1593
|
size = p.numel()
|
1319
|
-
copy_stochastic_(p, update[start: start + size].view_as(p))
|
1594
|
+
copy_stochastic_(p, update[start : start + size].view_as(p))
|
1320
1595
|
start += size
|
1321
1596
|
|
1322
1597
|
|
@@ -1326,7 +1601,7 @@ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
|
|
1326
1601
|
update = update.flatten()
|
1327
1602
|
for p in params:
|
1328
1603
|
size = p.numel()
|
1329
|
-
stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
|
1604
|
+
stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
|
1330
1605
|
start += size
|
1331
1606
|
|
1332
1607
|
|
@@ -1337,16 +1612,19 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
|
|
1337
1612
|
update = update.flatten()
|
1338
1613
|
for p in params:
|
1339
1614
|
size = p.numel()
|
1340
|
-
outputs.append(update[start: start + size].view_as(p))
|
1615
|
+
outputs.append(update[start : start + size].view_as(p))
|
1341
1616
|
start += size
|
1342
1617
|
return outputs
|
1343
1618
|
|
1344
1619
|
|
1620
|
+
@decorator_knowngood
|
1345
1621
|
def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
|
1346
|
-
|
1622
|
+
last_dim = x[0].shape[-remaining:] if remaining else []
|
1623
|
+
return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
|
1347
1624
|
|
1348
1625
|
|
1349
|
-
|
1626
|
+
@decorator_knowngood
|
1627
|
+
def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
|
1350
1628
|
vs = []
|
1351
1629
|
gs = []
|
1352
1630
|
for g_ in g:
|
@@ -1356,22 +1634,27 @@ def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
|
|
1356
1634
|
return flatten(vs), flatten(gs)
|
1357
1635
|
|
1358
1636
|
|
1359
|
-
|
1637
|
+
@decorator_knowngood
|
1638
|
+
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
|
1639
|
+
md = min_dtype(args)
|
1640
|
+
return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
|
1641
|
+
|
1642
|
+
|
1643
|
+
def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
|
1360
1644
|
order = G.dim()
|
1361
|
-
if
|
1362
|
-
|
1363
|
-
conjB =
|
1364
|
-
|
1365
|
-
A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
|
1366
|
-
Q = [promote(q) for q in Q]
|
1645
|
+
if order > 1:
|
1646
|
+
conjB = conjB.view_as(G).permute(*range(1, order), 0)
|
1647
|
+
conjB = conjB.to(promote(G.dtype))
|
1648
|
+
A = casted_einsum(exprA, *Q, G)
|
1367
1649
|
for i, q in enumerate(Q):
|
1650
|
+
q = promote(q)
|
1368
1651
|
if q.dim() <= 1:
|
1369
1652
|
conjB /= q
|
1370
1653
|
else:
|
1371
|
-
|
1372
|
-
|
1654
|
+
solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
|
1655
|
+
conjB = solved.reshape_as(conjB)
|
1373
1656
|
if i < order - 1:
|
1374
|
-
conjB =
|
1657
|
+
conjB = conjB.transpose(i, -1)
|
1375
1658
|
return A, conjB
|
1376
1659
|
|
1377
1660
|
|
@@ -1407,9 +1690,12 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
|
|
1407
1690
|
term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
|
1408
1691
|
term1 = torch.mm(term1, q.to(term1.dtype))
|
1409
1692
|
if store_triu_as_line:
|
1410
|
-
term1 = triu_to_line([term1])[0][1]
|
1411
|
-
|
1412
|
-
|
1693
|
+
term1 = triu_to_line([term1])[0][1] # Convert update to line format
|
1694
|
+
# Apply update directly to the tensor part of the state tuple o[1]
|
1695
|
+
stochastic_add_(o[1], term1, -1)
|
1696
|
+
else:
|
1697
|
+
# Apply update to the state tensor o
|
1698
|
+
stochastic_add_(o, term1, -1)
|
1413
1699
|
|
1414
1700
|
|
1415
1701
|
@decorator_knowngood
|
@@ -1619,8 +1905,9 @@ def warn_once(msg):
|
|
1619
1905
|
_warned.add(msg)
|
1620
1906
|
|
1621
1907
|
|
1622
|
-
def psgd_should_update(
|
1623
|
-
|
1908
|
+
def psgd_should_update(
|
1909
|
+
group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
|
1910
|
+
):
|
1624
1911
|
group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
|
1625
1912
|
if not isinstance(prob, float):
|
1626
1913
|
prob = prob(group[f"{name}_prob_step"])
|
@@ -1632,8 +1919,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
|
|
1632
1919
|
|
1633
1920
|
|
1634
1921
|
@decorator_knowngood
|
1635
|
-
def precond_grad_cached_(
|
1636
|
-
|
1922
|
+
def precond_grad_cached_(
|
1923
|
+
expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
|
1924
|
+
):
|
1637
1925
|
if caution:
|
1638
1926
|
ea = _compilable_cautioning(grad, ea)
|
1639
1927
|
md = min_dtype(list(cached_q) + [ea])
|
@@ -1753,12 +2041,79 @@ def merge_group(group, *tensors):
|
|
1753
2041
|
|
1754
2042
|
out = []
|
1755
2043
|
for t in tensors:
|
1756
|
-
append_or_extend(
|
1757
|
-
|
1758
|
-
|
2044
|
+
append_or_extend(
|
2045
|
+
out,
|
2046
|
+
dim_merger(
|
2047
|
+
t,
|
2048
|
+
group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
|
2049
|
+
group.get("split", False),
|
2050
|
+
),
|
2051
|
+
)
|
1759
2052
|
return out
|
1760
2053
|
|
1761
2054
|
|
2055
|
+
@decorator_knowngood
|
2056
|
+
def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
|
2057
|
+
for g_, u_, s_, d_ in zip(grads, update, state, delta):
|
2058
|
+
g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
|
2059
|
+
next_d = d * (g * s).sum()
|
2060
|
+
s = s + u * d
|
2061
|
+
next_d = next_d / s.abs().sum()
|
2062
|
+
next_d = torch.maximum(next_d, d)
|
2063
|
+
copy_stochastic_(u_, u * d)
|
2064
|
+
copy_stochastic_(d_, next_d)
|
2065
|
+
copy_stochastic_(s_, s)
|
2066
|
+
|
2067
|
+
|
2068
|
+
def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
|
2069
|
+
grads, update, state, delta = list_guard(grads, update, state, delta)
|
2070
|
+
_compilable_d_adapt_(grads, update, state, delta)
|
2071
|
+
|
2072
|
+
|
2073
|
+
@decorator_knowngood
|
2074
|
+
def _compilable_lr_adapt_(
|
2075
|
+
grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
|
2076
|
+
):
|
2077
|
+
for g_, u_, s_, d_ in zip(grads, update, state, delta):
|
2078
|
+
g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
|
2079
|
+
lr_grad = d.sigmoid()
|
2080
|
+
lr_grad = lr_grad * (1 - lr_grad)
|
2081
|
+
lr_grad = lr_grad * (s * g).mean()
|
2082
|
+
d = d - lr_grad * lr_lr
|
2083
|
+
copy_stochastic_(d_, d)
|
2084
|
+
copy_stochastic_(u_, u * d.sigmoid())
|
2085
|
+
copy_stochastic_(s_, u)
|
2086
|
+
|
2087
|
+
|
2088
|
+
def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
|
2089
|
+
grads, update, state, delta = list_guard(grads, update, state, delta)
|
2090
|
+
lr_lr = scalar_guard(lr_lr, grads[0])
|
2091
|
+
_compilable_lr_adapt_(grads, update, state, delta, lr_lr)
|
2092
|
+
|
2093
|
+
|
2094
|
+
@decorator_knowngood
|
2095
|
+
def _compilable_pointwise_lr_adapt_(
|
2096
|
+
grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
|
2097
|
+
):
|
2098
|
+
for g_, u_, s_, d_ in zip(grads, update, state, delta):
|
2099
|
+
g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
|
2100
|
+
lr_grad = d.sigmoid()
|
2101
|
+
lr_grad = lr_grad * (1 - lr_grad)
|
2102
|
+
lr_grad = lr_grad * s * g
|
2103
|
+
d = d - lr_grad * lr_lr
|
2104
|
+
copy_stochastic_(d_, d)
|
2105
|
+
copy_stochastic_(u_, u * d.sigmoid())
|
2106
|
+
copy_stochastic_(s_, u)
|
2107
|
+
|
2108
|
+
|
2109
|
+
def pointwise_lr_adaptation(
|
2110
|
+
grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
|
2111
|
+
):
|
2112
|
+
grads, update, state, delta = list_guard(grads, update, state, delta)
|
2113
|
+
lr_lr = scalar_guard(lr_lr, grads[0])
|
2114
|
+
_compilable_lr_adapt_(grads, update, state, delta, lr_lr)
|
2115
|
+
|
2116
|
+
|
1762
2117
|
def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
|
1763
2118
|
optimizers = {}
|
1764
2119
|
|
@@ -1781,8 +2136,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
|
|
1781
2136
|
|
1782
2137
|
o = optimizer(parameters, *args, **kwargs)
|
1783
2138
|
step_fn = o.step
|
1784
|
-
o.step = functools.partial(
|
1785
|
-
msg="You're trying to call `step` on a fused optimizer. This will not do anything."
|
2139
|
+
o.step = functools.partial(
|
2140
|
+
warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
|
2141
|
+
)
|
1786
2142
|
|
1787
2143
|
def _step(p: Tensor):
|
1788
2144
|
seen_params.add(p)
|