heavyball 1.7.1__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -16
- heavyball/chainable.py +338 -190
- heavyball/helpers.py +804 -0
- heavyball/utils.py +813 -252
- heavyball-2.0.0.dev0.dist-info/METADATA +109 -0
- heavyball-2.0.0.dev0.dist-info/RECORD +9 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/WHEEL +1 -1
- heavyball/optimizations/__init__.py +0 -38
- heavyball/optimizations/integrator.py +0 -169
- heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1.dist-info/METADATA +0 -939
- heavyball-1.7.1.dist-info/RECORD +0 -11
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dev0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,28 +1,28 @@
|
|
1
|
+
import collections
|
1
2
|
import contextlib
|
2
3
|
import functools
|
3
4
|
import gc
|
4
5
|
import inspect
|
5
6
|
import math
|
7
|
+
import pickle
|
6
8
|
import random
|
7
9
|
import re
|
8
10
|
import string
|
9
11
|
import warnings
|
10
|
-
from typing import Callable, List, Optional, Tuple, Union
|
12
|
+
from typing import Callable, List, Literal, Optional, Tuple, Union
|
11
13
|
|
12
14
|
import numpy as np
|
13
15
|
import torch
|
14
16
|
from torch import Tensor
|
15
|
-
from torch._dynamo import config
|
16
17
|
from torch._dynamo.exc import TorchDynamoException
|
17
18
|
from torch.backends import cudnn, opt_einsum
|
18
19
|
from torch.utils._pytree import tree_map
|
19
20
|
|
20
|
-
config.cache_size_limit = 2**16
|
21
|
-
|
22
21
|
compile_mode = "max-autotune-no-cudagraphs"
|
23
22
|
dynamic = False
|
24
23
|
compile_mode_recommended_to_none = None
|
25
|
-
zeroth_power_mode = "
|
24
|
+
zeroth_power_mode = "newtonschulz"
|
25
|
+
precise_zeroth_power_mode = "qr" # or svd
|
26
26
|
tiny_bf16 = torch.finfo(torch.bfloat16).tiny
|
27
27
|
_cudnn_double_backward_pattern = re.compile(
|
28
28
|
r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
|
@@ -50,7 +50,7 @@ def decorator(func):
|
|
50
50
|
return _fn
|
51
51
|
|
52
52
|
|
53
|
-
def decorator_knowngood(func: Callable):
|
53
|
+
def decorator_knowngood(func: Callable, fullgraph: bool = True):
|
54
54
|
compiled = None
|
55
55
|
|
56
56
|
@functools.wraps(func)
|
@@ -59,7 +59,7 @@ def decorator_knowngood(func: Callable):
|
|
59
59
|
return func(*args, **kwargs)
|
60
60
|
nonlocal compiled
|
61
61
|
if compiled is None:
|
62
|
-
compiled = torch.compile(fullgraph=
|
62
|
+
compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
|
63
63
|
return compiled(*args, **kwargs)
|
64
64
|
|
65
65
|
return _fn
|
@@ -68,6 +68,16 @@ def decorator_knowngood(func: Callable):
|
|
68
68
|
einsum_base = string.ascii_lowercase
|
69
69
|
|
70
70
|
|
71
|
+
@decorator_knowngood
def compiled_einsum(expr, *args):
    """
    Run ``torch.einsum`` through the module's compile decorator.

    This wrapper exists because calling einsum uncompiled was measured to be about twice as slow once three
    size-1 dimensions are added; see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957

    :param expr: Einsum subscript expression.
    :param args: Operands, forwarded to ``torch.einsum`` unchanged.
    :return: The result of ``torch.einsum(expr, *args)``.
    """
    operands = args
    return torch.einsum(expr, *operands)
|
79
|
+
|
80
|
+
|
71
81
|
@decorator_knowngood
|
72
82
|
def _compilable_schedule_free_(
|
73
83
|
p: List[Tensor],
|
@@ -122,6 +132,47 @@ def schedule_free_(
|
|
122
132
|
return weight_sum
|
123
133
|
|
124
134
|
|
135
|
+
@decorator_knowngood
def _compilable_msam(
    lr: Tensor,
    beta1: Tensor,
    param: List[Tensor],
    z: List[Tensor],
    update: List[Tensor],
    grad: List[Tensor],
    exp_avg: List[Tensor],
    caution: bool,
    decay: Tensor,
    sam_step_size: Tensor,
):
    """
    Compiled inner step for ``msam_``; writes its results into ``z`` and ``param`` in place.

    For every tensor the momentum ``u`` is obtained from ``_lerp(exp_avg, update, beta1)`` and then:

    * ``z <- z * (1 - decay * lr) - lr * u``
    * ``param <- z_new - sam_step_size * u / max(||u||, 1e-8)``

    :param lr: Learning rate.
    :param beta1: Interpolation factor handed to ``_lerp``.
    :param param: Tensors overwritten with the perturbed point.
    :param z: Tensors overwritten with the unperturbed update.
    :param update: Update tensors that are blended into ``exp_avg``.
    :param grad: Gradients; only read when ``caution`` is set.
    :param exp_avg: Momentum buffers passed to ``_lerp``.
    :param caution: If true, ``u`` is passed through ``_compilable_cautioning`` together with the promoted gradient.
    :param decay: Weight-decay coefficient (multiplied by ``lr``).
    :param sam_step_size: Length of the normalized perturbation applied to ``param``.
    """
    momentum = _lerp(exp_avg, update, beta1)
    for mom, g, anchor, weight in zip(momentum, grad, z, param):
        mom = mom.view_as(anchor)
        anchor32 = promote(anchor)
        if caution:
            mom = _compilable_cautioning(promote(g), mom)
        anchor32 = anchor32 * (1 - decay * lr) + mom * -lr
        # z must be written before param: param is derived from the freshly updated z.
        copy_stochastic_(anchor, anchor32)
        unit_direction = mom / mom.norm().clamp(min=1e-8)
        copy_stochastic_(weight, anchor32 + unit_direction * -sam_step_size)
|
157
|
+
|
158
|
+
|
159
|
+
def msam_(
    lr: float,
    beta1: float,
    param: List[Tensor],
    z: List[Tensor],
    update: List[Tensor],
    grad: List[Tensor],
    exp_avg: List[Tensor],
    caution: bool,
    weight_decay: float,
    sam_step_size: float,
):
    """
    Public entry point of the MSAM step: normalizes the arguments and calls ``_compilable_msam``.

    The tensor arguments go through ``list_guard`` and the scalar hyper-parameters through ``scalar_guard``
    (with ``exp_avg[0]`` as the reference tensor) before the compiled kernel is invoked. Nothing is returned;
    the kernel writes its results in place.

    :param lr: Learning rate.
    :param beta1: Momentum interpolation factor.
    :param param: Parameter tensor(s).
    :param z: Companion tensor(s) updated alongside ``param``.
    :param update: Update tensor(s).
    :param grad: Gradient tensor(s).
    :param exp_avg: Momentum buffer(s).
    :param caution: Forwarded unchanged to the kernel.
    :param weight_decay: Passed to the kernel as its ``decay`` argument.
    :param sam_step_size: Perturbation length passed to the kernel.
    """
    tensor_lists = list_guard(param, z, update, grad, exp_avg)
    param, z, update, grad, exp_avg = tensor_lists
    scalars = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
    lr, beta1, weight_decay, sam_step_size = scalars
    _compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
|
174
|
+
|
175
|
+
|
125
176
|
def append_or_extend(base, new):
|
126
177
|
if isinstance(new, list):
|
127
178
|
base.extend(new)
|
@@ -161,7 +212,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
161
212
|
new_shape = [grad.shape[0], *new_shape[::-1]]
|
162
213
|
new_grad = grad.reshape(new_shape)
|
163
214
|
if not split:
|
164
|
-
return new_grad
|
215
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
165
216
|
|
166
217
|
grads = [new_grad]
|
167
218
|
for i, sh in reversed(list(enumerate(new_shape[:]))):
|
@@ -172,7 +223,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
172
223
|
continue
|
173
224
|
grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
|
174
225
|
if len(grads) == 1:
|
175
|
-
return new_grad
|
226
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
176
227
|
new_grads = []
|
177
228
|
for g in grads:
|
178
229
|
append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
|
@@ -279,16 +330,29 @@ def clean():
|
|
279
330
|
|
280
331
|
|
281
332
|
def _ignore_warning(msg):
|
282
|
-
warnings.filterwarnings("ignore", f".*{msg}.*")
|
333
|
+
warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
|
334
|
+
|
283
335
|
|
336
|
+
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
|
337
|
+
import opt_einsum as _opt_einsum
|
284
338
|
|
285
|
-
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
286
339
|
cudnn.benchmark = True
|
287
340
|
cudnn.deterministic = False
|
288
341
|
cudnn.benchmark_limit = benchmark_limit
|
289
342
|
torch.use_deterministic_algorithms(False)
|
290
343
|
torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
|
291
|
-
opt_einsum.set_flags(True
|
344
|
+
opt_einsum.set_flags(True)
|
345
|
+
if einsum_strategy == "heavyball":
|
346
|
+
opt_einsum.strategy = "auto-hq"
|
347
|
+
choices = _opt_einsum.paths._AUTO_HQ_CHOICES
|
348
|
+
for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
|
349
|
+
if isinstance(fn, int):
|
350
|
+
fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
|
351
|
+
for i in range(max(choices.keys()), max_val):
|
352
|
+
if i not in choices:
|
353
|
+
choices[i] = fn
|
354
|
+
else:
|
355
|
+
opt_einsum.strategy = einsum_strategy
|
292
356
|
|
293
357
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
294
358
|
_ignore_warning(
|
@@ -297,6 +361,9 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
|
297
361
|
_ignore_warning(
|
298
362
|
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
299
363
|
)
|
364
|
+
_ignore_warning(
|
365
|
+
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
|
366
|
+
)
|
300
367
|
|
301
368
|
|
302
369
|
@decorator
|
@@ -316,15 +383,6 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
316
383
|
return X.to(G.dtype)
|
317
384
|
|
318
385
|
|
319
|
-
def ortho(x):
|
320
|
-
if zeroth_power_mode == "qr":
|
321
|
-
return torch.linalg.qr(x).Q
|
322
|
-
if zeroth_power_mode == "svd":
|
323
|
-
u, _s, v = torch.linalg.svd(x)
|
324
|
-
return u @ v.T
|
325
|
-
raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
|
326
|
-
|
327
|
-
|
328
386
|
@decorator_knowngood
|
329
387
|
def _compilable_heavyball_momentum_(state, grad, beta):
|
330
388
|
s32, g32 = [list(map(promote, x)) for x in (state, grad)]
|
@@ -377,7 +435,7 @@ def _compilable_grafting(magnitude, direction):
|
|
377
435
|
|
378
436
|
|
379
437
|
@decorator_knowngood
|
380
|
-
def
|
438
|
+
def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
|
381
439
|
if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
|
382
440
|
y = zeropower_via_newtonschulz5(x, 5)
|
383
441
|
elif mode == "qr":
|
@@ -395,9 +453,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
|
|
395
453
|
y = _compilable_grafting(x, y)
|
396
454
|
else:
|
397
455
|
raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
|
456
|
+
if out is None:
|
457
|
+
return y
|
458
|
+
|
398
459
|
set_(out, y)
|
399
460
|
|
400
461
|
|
462
|
+
def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
    """
    Orthogonalize ``x`` via ``_compilable_orthogonal_``.

    :param x: Input tensor.
    :param mode: Orthogonalization mode; any falsy value selects the module-level ``zeroth_power_mode``.
    :param out: Optional destination tensor, forwarded as-is.
    :param scale_mode: Scaling mode, forwarded as-is.
    :return: Whatever ``_compilable_orthogonal_`` returns.
    """
    selected_mode = mode or zeroth_power_mode
    return _compilable_orthogonal_(x, selected_mode, out, scale_mode)
|
464
|
+
|
465
|
+
|
401
466
|
@decorator_knowngood
|
402
467
|
def _compilable_scatter_set(target, source, index):
|
403
468
|
target[:] = source.contiguous()[index].reshape_as(target)
|
@@ -413,6 +478,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
413
478
|
:param Q: List of current eigenbases (updated in-place to Q_new).
|
414
479
|
:param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
|
415
480
|
"""
|
481
|
+
if exp_avg.dim() == 0: # preconditioning doesn't make sense here
|
482
|
+
Q.clear()
|
483
|
+
return
|
484
|
+
|
416
485
|
if isinstance(Q, list) and not Q:
|
417
486
|
return
|
418
487
|
|
@@ -430,10 +499,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
430
499
|
q_old = promote(q.data)
|
431
500
|
|
432
501
|
tmp = m @ q_old
|
433
|
-
est_eig =
|
502
|
+
est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
|
434
503
|
sort_idx = torch.argsort(est_eig, descending=True)
|
435
504
|
|
436
|
-
tmp[:, sort_idx]
|
505
|
+
tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
|
437
506
|
new_qs.append(tmp)
|
438
507
|
|
439
508
|
if exp_avg is None:
|
@@ -453,7 +522,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
453
522
|
out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
454
523
|
|
455
524
|
subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
|
456
|
-
exp_avg_new =
|
525
|
+
exp_avg_new = compiled_einsum(
|
457
526
|
subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
|
458
527
|
)
|
459
528
|
copy_stochastic_(exp_avg, exp_avg_new)
|
@@ -568,6 +637,19 @@ def scalar_guard(*args):
|
|
568
637
|
return out
|
569
638
|
|
570
639
|
|
640
|
+
def broadcastable_list_guard(*xs):
    """
    Normalize the inputs to lists and repeat short ones to a common length.

    Steps:
    1. every input is passed through ``list_guard``;
    2. the first element of the first input whose head is a ``Tensor`` becomes the reference tensor;
    3. inputs whose head is not a ``Tensor`` are converted with ``scalar_guard(*x, ref)`` and wrapped again by
       ``list_guard``;
    4. inputs of length <= 1 are repeated ``max_len`` times, where ``max_len`` is the longest input.

    :param xs: Tensor lists and/or scalar values.
    :return: One list per input.
    :raises ValueError: If no input starts with a ``Tensor``.
    """
    xs = list_guard(*xs)

    ref = None
    for candidate in xs:
        head = candidate[0]
        if isinstance(head, Tensor):
            ref = head
            break
    if ref is None:
        raise ValueError("No tensor-valued input given")

    guarded = []
    for x in xs:
        if isinstance(x[0], Tensor):
            guarded.append(x)
        else:
            guarded.append(list_guard(scalar_guard(*x, ref)))

    longest = max(len(x) for x in guarded)
    return [x if len(x) > 1 else x * longest for x in guarded]
|
651
|
+
|
652
|
+
|
571
653
|
@decorator_knowngood
|
572
654
|
def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
|
573
655
|
for x_, y_ in zip(x, y):
|
@@ -577,7 +659,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
|
|
577
659
|
|
578
660
|
|
579
661
|
def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
|
580
|
-
x, y =
|
662
|
+
x, y = broadcastable_list_guard(x, y)
|
581
663
|
alpha = scalar_guard(alpha, x[0])
|
582
664
|
_compilable_stochastic_add_(x, y, alpha)
|
583
665
|
|
@@ -591,7 +673,7 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
|
|
591
673
|
|
592
674
|
|
593
675
|
def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
|
594
|
-
x, y =
|
676
|
+
x, y = broadcastable_list_guard(x, y)
|
595
677
|
alpha, divisor = scalar_guard(alpha, divisor, x[0])
|
596
678
|
_compilable_stochastic_add_divide_(x, y, alpha, divisor)
|
597
679
|
|
@@ -605,7 +687,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
|
605
687
|
|
606
688
|
|
607
689
|
def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
608
|
-
x, y =
|
690
|
+
x, y = broadcastable_list_guard(x, y)
|
609
691
|
_compilable_stochastic_multiply_(x, y)
|
610
692
|
|
611
693
|
|
@@ -624,7 +706,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
624
706
|
b = einsum_base[idx]
|
625
707
|
g0 = einsum_base[: grad.dim()]
|
626
708
|
g1 = g0.replace(b, b.upper())
|
627
|
-
outer_product =
|
709
|
+
outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
|
628
710
|
stochastic_lerp_(m, outer_product, 1 - beta)
|
629
711
|
|
630
712
|
|
@@ -706,7 +788,7 @@ def project(grad, Q, back: bool):
|
|
706
788
|
preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
|
707
789
|
if preconditioners:
|
708
790
|
out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
|
709
|
-
out =
|
791
|
+
out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
|
710
792
|
grad = out.to(grad.dtype)
|
711
793
|
return grad
|
712
794
|
|
@@ -714,24 +796,28 @@ def project(grad, Q, back: bool):
|
|
714
796
|
@contextlib.contextmanager
def patch_backward():
    """
    Context manager that forces ``create_graph=True`` on every backward call made inside it.

    While active, ``torch.Tensor.backward`` and ``torch.autograd.backward`` are replaced by wrappers that
    re-dispatch to the original function with ``create_graph=True``. Both attributes are restored on exit,
    including when the body raises.
    """

    @contextlib.contextmanager
    def patch_module(module):
        # Swap `module.backward` for the forcing wrapper and restore it afterwards.
        original = module.backward
        try:
            signature = inspect.signature(original)

            @functools.wraps(original)
            def patched_backward(*args, **kwargs):
                # Bind positional AND keyword arguments together. Binding only *args fails with
                # "missing a required argument" whenever a required parameter is passed by keyword,
                # e.g. torch.autograd.backward(tensors=...).
                bound = signature.bind(*args, **kwargs)
                bound.apply_defaults()
                call_kwargs = dict(bound.arguments)
                call_kwargs["create_graph"] = True
                # Everything is re-passed by keyword; this relies on the patched callables having
                # no positional-only or variadic parameters.
                return original(**call_kwargs)

            module.backward = patched_backward
            yield
        finally:
            module.backward = original

    with contextlib.ExitStack() as stack:
        stack.enter_context(patch_module(torch.Tensor))
        stack.enter_context(patch_module(torch.autograd))
        yield
|
736
822
|
|
737
823
|
|
@@ -743,6 +829,9 @@ class ExactHVPFailed(ValueError):
|
|
743
829
|
pass
|
744
830
|
|
745
831
|
|
832
|
+
use_default = object()
|
833
|
+
|
834
|
+
|
746
835
|
class StatefulOptimizer(torch.optim.Optimizer):
|
747
836
|
"""
|
748
837
|
finite_differences saves memory, but needs more compute. (Alternative is true HVP)
|
@@ -755,7 +844,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
755
844
|
compile_step: bool = False
|
756
845
|
hessian_approx: bool = False
|
757
846
|
precond_schedule: Union[Callable, float, None] = None
|
758
|
-
stochastic_schedule: bool =
|
847
|
+
stochastic_schedule: bool | Literal[use_default] = use_default
|
759
848
|
finite_differences: bool = False
|
760
849
|
fallback_to_finite_differences: bool = True
|
761
850
|
_fallback_enabled: bool = False
|
@@ -765,18 +854,61 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
765
854
|
super().__init__(params, {**defaults, "foreach": foreach})
|
766
855
|
self.use_ema = use_ema
|
767
856
|
self.mapping = {}
|
768
|
-
self.
|
769
|
-
|
857
|
+
self.mapping_inverse = {}
|
858
|
+
|
859
|
+
if self.stochastic_schedule is use_default:
|
860
|
+
stochastic_schedule = None
|
861
|
+
for group in self.param_groups:
|
862
|
+
new = group.get("stochastic_schedule", stochastic_schedule)
|
863
|
+
if stochastic_schedule is not None and new != stochastic_schedule:
|
864
|
+
raise ValueError("All parameter groups must have the same stochastic_schedule.")
|
865
|
+
stochastic_schedule = new
|
866
|
+
self.stochastic_schedule = stochastic_schedule
|
867
|
+
|
868
|
+
self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
|
869
|
+
self.precond_rng = random.Random(0x12312)
|
770
870
|
self._is_preconditioning = None
|
771
871
|
|
772
872
|
if self.hessian_approx and self.compile_step:
|
773
873
|
raise ValueError("Hessian approximation can't be used with compile_step.")
|
774
874
|
|
875
|
+
self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
|
876
|
+
self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
|
877
|
+
self._init_mapping()
|
878
|
+
|
879
|
+
def _store_stats(self, state_dict: dict[str, any]):
|
880
|
+
state_dict["heavyball"] = {
|
881
|
+
"inner_group": self.inner_group,
|
882
|
+
"precond_rng": pickle.dumps(self.precond_rng),
|
883
|
+
"use_ema": self.use_ema,
|
884
|
+
"ema_decay": self.ema_decay,
|
885
|
+
"compile_step": self.compile_step,
|
886
|
+
"hessian_approx": self.hessian_approx,
|
887
|
+
"precond_schedule": pickle.dumps(self.precond_schedule),
|
888
|
+
"stochastic_schedule": self.stochastic_schedule,
|
889
|
+
"fallback_to_finite_differences": self.fallback_to_finite_differences,
|
890
|
+
"_fallback_enabled": self._fallback_enabled,
|
891
|
+
"hvp_interval": self.hvp_interval,
|
892
|
+
}
|
893
|
+
|
894
|
+
def _load_stats(self, state_dict):
|
895
|
+
sd = state_dict.pop("heavyball", {})
|
896
|
+
for k, v in sd.items():
|
897
|
+
if k in ("precond_rng", "precond_schedule"):
|
898
|
+
v = pickle.loads(v)
|
899
|
+
setattr(self, k, v)
|
900
|
+
|
775
901
|
def get_groups(self, group):
    """Return the groups to process for ``group``; this implementation yields just the group itself."""
    groups = [group]
    return groups
|
777
903
|
|
778
|
-
|
779
|
-
|
904
|
+
@functools.lru_cache(maxsize=None)
def state_(self, arg: Tensor, fail: bool = True):
    """
    Return the per-view state dict for ``arg``.

    ``arg`` is looked up in ``self.mapping_inverse`` to find its owning parameter and view index; the state
    lives at ``self.state[owner][index]`` and is created on first access.

    :param arg: Tensor registered in ``self.mapping_inverse``.
    :param fail: If false, an unregistered ``arg`` yields an empty dict instead of raising.
    :return: The state dict of this view.
    :raises KeyError: If ``arg`` is not registered and ``fail`` is true.
    """
    # The guard must test the same dict the lookup below uses. `self.mapping` is keyed by the parameter,
    # not by its views, so testing it returned {} for registered views and let unregistered
    # parameters fall through to a KeyError.
    if not fail and arg not in self.mapping_inverse:
        # NOTE(review): lru_cache also memoizes this empty dict (and keeps `self`/`arg` alive);
        # confirm the cache is cleared wherever the mapping or self.state is replaced.
        return {}
    state_param, index = self.mapping_inverse[arg]
    if state_param not in self.state:
        self.state[state_param] = collections.defaultdict(dict)
    return self.state[state_param][index]
|
780
912
|
|
781
913
|
def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
|
782
914
|
for p, g in zip(p_list, g_list):
|
@@ -786,6 +918,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
786
918
|
old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
|
787
919
|
mars_correction(g_list, old_gs, mars_gamma, beta)
|
788
920
|
|
921
|
+
def _init_mapping(self, group: dict | None = None):
|
922
|
+
if group is None:
|
923
|
+
for group in self.param_groups:
|
924
|
+
self._init_mapping(group)
|
925
|
+
return
|
926
|
+
|
927
|
+
for p in group["params"]:
|
928
|
+
if p not in self.mapping:
|
929
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
930
|
+
for i, pv in enumerate(p_views):
|
931
|
+
self.mapping_inverse[pv] = (p, i)
|
932
|
+
|
789
933
|
def split_p_and_g_in_group(
|
790
934
|
self,
|
791
935
|
group: dict,
|
@@ -809,6 +953,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
809
953
|
p_views = self.mapping[p]
|
810
954
|
else:
|
811
955
|
self.mapping[p] = p_views = merge_group(group, p)
|
956
|
+
for i, pv in enumerate(p_views):
|
957
|
+
self.mapping_inverse[pv] = (p, i)
|
812
958
|
|
813
959
|
vector = getattr(p, "vector", None)
|
814
960
|
hessian_vector = getattr(p, "hessian_vector", None)
|
@@ -957,8 +1103,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
957
1103
|
raise ValueError("Hessian approximation requires a closure.")
|
958
1104
|
return None
|
959
1105
|
|
960
|
-
step = self.
|
961
|
-
if not hessian_approx or step % self.hvp_interval == 0:
|
1106
|
+
step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
|
1107
|
+
if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
|
962
1108
|
with torch.enable_grad():
|
963
1109
|
loss = closure()
|
964
1110
|
return loss
|
@@ -997,12 +1143,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
997
1143
|
if self.precond_schedule is None:
|
998
1144
|
self._is_preconditioning = False
|
999
1145
|
else:
|
1000
|
-
self._is_preconditioning = psgd_should_update(self.
|
1146
|
+
self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
|
1001
1147
|
loss = self._handle_closure(closure)
|
1002
1148
|
|
1003
1149
|
# we assume that parameters are constant and that there are no excessive recompiles
|
1004
1150
|
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
|
1005
1151
|
for group in self.param_groups:
|
1152
|
+
if "param_count" not in group:
|
1153
|
+
group["param_count"] = sum(p.numel() for p in group["params"])
|
1006
1154
|
group["is_preconditioning"] = self._is_preconditioning
|
1007
1155
|
self._step(group)
|
1008
1156
|
if self.use_ema:
|
@@ -1306,74 +1454,115 @@ def stable_exp(x: Tensor):
|
|
1306
1454
|
return torch.where(x > 0, 1 / (-x).exp(), x.exp())
|
1307
1455
|
|
1308
1456
|
|
1457
|
+
def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
|
1458
|
+
# ln(mean(x ** pow) ** (1 / pow / 2))
|
1459
|
+
normalization = math.log(x.numel())
|
1460
|
+
x = x.double()
|
1461
|
+
x = x.abs()
|
1462
|
+
x = x.clamp(min=eps)
|
1463
|
+
x = x.log()
|
1464
|
+
x = x * pow
|
1465
|
+
x = x.flatten()
|
1466
|
+
x = x.logsumexp(dim=0) # log(sum(exp( log(x) * P ) - more stable than sum(x ** P)
|
1467
|
+
x = x - normalization # sum -> mean (divide by x.numel() in log space)
|
1468
|
+
return x / pow / 2
|
1469
|
+
|
1470
|
+
|
1309
1471
|
@decorator_knowngood
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
    """
    Compute ``1 / (mean(x ** pow) ** (1 / pow / 2))``.

    The log of the root comes from ``_lse_mean``; it is negated and passed to ``stable_exp``.

    :param x: Input tensor.
    :param pow: Exponent of the power mean.
    :param eps: Lower clamp forwarded to ``_lse_mean``.
    """
    log_root = _lse_mean(x, pow, eps)
    return stable_exp(-log_root)
|
1315
1475
|
|
1316
1476
|
|
1317
1477
|
@decorator_knowngood
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
    """
    Compute ``mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)``.

    The ratio is formed as a difference of the two ``_lse_mean`` values, which is then passed to
    ``stable_exp``.

    :param x: Numerator tensor.
    :param y: Denominator tensor.
    :param pow0: Exponent for ``x``.
    :param pow1: Exponent for ``y``.
    :param eps: Lower clamp forwarded to ``_lse_mean``.
    """
    log_numerator = _lse_mean(x, pow0, eps)
    log_denominator = _lse_mean(y, pow1, eps)
    return stable_exp(log_numerator - log_denominator)
|
1325
1481
|
|
1326
|
-
y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
|
1327
|
-
y_normed = y_normed / pow1 / 2
|
1328
1482
|
|
1329
|
-
|
1483
|
+
class PrecondInitError(ValueError):
    """Raised when no usable preconditioner init scale can be determined."""
|
1330
1485
|
|
1331
1486
|
|
1332
|
-
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float =
|
1487
|
+
def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
    """
    Determine the initial preconditioner scale.

    With ``scale=None`` the scale is computed automatically: ``mean_root(grad, 4)`` when no Hessian-vector
    product is given, otherwise ``divided_root(vector, hessian_vector, 2, 4)``. The result is multiplied by
    ``scale_scale`` and raised to ``scale_power`` (default 0.5). A manually given ``scale`` is used as-is;
    ``scale_scale`` and ``scale_power`` are then not applied and only trigger a warning.

    :param scale: Manual scale, or ``None`` for the automatic heuristic.
    :param scale_scale: Multiplier for the automatic scale; ``None`` means 1.
    :param scale_power: Exponent for the automatic scale; ``None`` means 0.5.
    :param grad: Gradient tensor.
    :param hessian_vector: Hessian-vector product, or ``None``.
    :param vector: Vector belonging to ``hessian_vector``, or ``None``.
    :param scale_max: Finite scales above this value trigger a warning.
    :return: The scale as a Python number.
    :raises PrecondInitError: If a manual scale is not finite, if an input is all zeros or non-finite while
        the automatic scale is out of range, or if the automatic scale is not finite.
    """
    automatic_scale = scale is None
    manual_hint = " Set it manually using `precond_init_scale=0.1`"
    scale_scale = 1 if scale_scale is None else scale_scale

    if not automatic_scale:
        warn_once(
            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
        )
        if scale_scale != 1:
            warn_once(
                "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
            )
        if scale_power is not None:
            warn_once(
                "precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
            )
    elif hessian_vector is None:
        scale = mean_root(grad, 4) * scale_scale
    else:
        scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale

    if automatic_scale:
        scale_power = 0.5 if scale_power is None else scale_power
        scale = scale**scale_power

    if isinstance(scale, torch.Tensor):
        scale = scale.item()  # slow, but necessary

    if np.isfinite(scale):
        if scale <= scale_max:
            return scale
        warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
        if not automatic_scale:
            # A finite manual value is honored even when it is large. Falling through here used to
            # raise "not finite" for a value that was in fact finite.
            return scale
        # automatic + out of range: fall through to the input sanity checks below
    elif not automatic_scale:
        raise PrecondInitError("The manually set precond_init_scale is not finite")

    for x in (grad, hessian_vector, vector):
        if x is None:
            continue
        if torch.allclose(x, torch.zeros_like(x)):
            raise PrecondInitError(
                f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
            )
        if not torch.isfinite(x).all().item():
            raise PrecondInitError("Grad or HVP is not finite")

    if np.isfinite(scale):
        return scale

    raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
|
1365
1540
|
|
1366
1541
|
|
1367
|
-
def init_lra(
|
1368
|
-
|
1369
|
-
|
1370
|
-
|
1542
|
+
def init_lra(
    grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
):
    """
    Create the initial low-rank factors ``U``, ``V`` and the diagonal ``d``.

    ``U`` and ``V`` are Gaussian tensors of shape ``(*grad.shape, rank)`` scaled by
    ``(param_count * (rank + eps)) ** -0.5``; ``d`` has the shape of ``grad`` and is filled with the scale
    returned by ``precond_init_scale``.

    :param grad: Tensor that provides shape and device.
    :param param_count: Parameter count used to normalize ``U`` and ``V``.
    :param scale: Manual init scale or ``None``; forwarded to ``precond_init_scale``.
    :param scale_scale: Forwarded to ``precond_init_scale``.
    :param scale_power: Forwarded to ``precond_init_scale``.
    :param rank: Size of the trailing dimension of ``U`` and ``V``.
    :param hessian_vector: Forwarded to ``precond_init_scale``.
    :param vector: Forwarded to ``precond_init_scale``.
    :param dtype: dtype of the created tensors.
    :param eps: Offset added to ``rank``, see the quoted note below.
    :return: Tuple ``(U, V, d)``.
    """
    # "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
    # https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
    scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
    uv_scale = (param_count * (rank + eps)) ** -0.5
    factor_shape = (*grad.shape, rank)

    def _random_factor():
        return torch.randn(factor_shape, dtype=dtype, device=grad.device) * uv_scale

    U = _random_factor()  # drawn before V to keep the RNG consumption order
    V = _random_factor()
    d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
    return U, V, d
|
1373
1553
|
|
1374
1554
|
|
1375
1555
|
def init_Q_exprs(
|
1376
|
-
grad,
|
1556
|
+
grad,
|
1557
|
+
scale,
|
1558
|
+
scale_scale,
|
1559
|
+
scale_power,
|
1560
|
+
max_size,
|
1561
|
+
min_ndim_triangular,
|
1562
|
+
memory_save_mode,
|
1563
|
+
hessian_vector,
|
1564
|
+
vector,
|
1565
|
+
dtype=None,
|
1377
1566
|
):
|
1378
1567
|
"""
|
1379
1568
|
For a scalar or tensor `grad`, we initialize its preconditioner Q and
|
@@ -1382,21 +1571,13 @@ def init_Q_exprs(
|
|
1382
1571
|
precond init scale computation from
|
1383
1572
|
https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
|
1384
1573
|
"""
|
1385
|
-
scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
|
1386
|
-
letters = string.ascii_lowercase + string.ascii_uppercase
|
1574
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1387
1575
|
dtype = dtype if dtype is not None else grad.dtype
|
1388
1576
|
shape = grad.shape
|
1389
1577
|
|
1390
1578
|
if len(shape) == 0: # scalar
|
1391
1579
|
Q = [scale * torch.ones_like(grad, dtype=dtype)]
|
1392
|
-
|
1393
|
-
exprGs = [",->"]
|
1394
|
-
exprP = ",,->"
|
1395
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1396
|
-
|
1397
|
-
# Tensor
|
1398
|
-
if len(shape) > 13:
|
1399
|
-
raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
|
1580
|
+
return Q
|
1400
1581
|
|
1401
1582
|
scale = scale ** (1 / len(shape))
|
1402
1583
|
|
@@ -1409,6 +1590,9 @@ def init_Q_exprs(
|
|
1409
1590
|
sorted_shape = sorted(shape)
|
1410
1591
|
if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
|
1411
1592
|
dim_diag[_max_idx(shape)] = True
|
1593
|
+
elif memory_save_mode == "one_triu":
|
1594
|
+
shape_ranks = np.argsort(np.argsort(shape)) # ranks
|
1595
|
+
dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
|
1412
1596
|
elif memory_save_mode == "all_diag":
|
1413
1597
|
dim_diag = [True for _ in shape]
|
1414
1598
|
else:
|
@@ -1418,66 +1602,90 @@ def init_Q_exprs(
|
|
1418
1602
|
)
|
1419
1603
|
|
1420
1604
|
Q = []
|
1421
|
-
piece1A, piece2A, piece3A = ([], "", "")
|
1422
|
-
exprGs = []
|
1423
|
-
piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
|
1424
1605
|
for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
|
1425
1606
|
if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
|
1426
1607
|
# use diagonal matrix as preconditioner for this dim
|
1427
1608
|
Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
|
1428
|
-
|
1429
|
-
piece1A.append(letters[i])
|
1430
|
-
piece2A = piece2A + letters[i]
|
1431
|
-
piece3A = piece3A + letters[i]
|
1432
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1433
|
-
subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
|
1434
|
-
exprGs.append(subscripts)
|
1435
|
-
piece1P.append(letters[i + 13])
|
1436
|
-
piece2P.append(letters[i + 13])
|
1437
|
-
piece3P = piece3P + letters[i + 13]
|
1438
|
-
piece4P = piece4P + letters[i + 13]
|
1439
1609
|
else:
|
1440
1610
|
# use triangular matrix as preconditioner for this dim
|
1441
1611
|
Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
|
1442
|
-
|
1443
|
-
piece2A = piece2A + letters[i + 13]
|
1444
|
-
piece3A = piece3A + letters[i]
|
1445
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1446
|
-
piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
|
1447
|
-
subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
|
1448
|
-
exprGs.append(subscripts)
|
1449
|
-
a, b, c = (letters[i], letters[i + 13], letters[i + 26])
|
1450
|
-
piece1P.append(a + b)
|
1451
|
-
piece2P.append(a + c)
|
1452
|
-
piece3P = piece3P + c
|
1453
|
-
piece4P = piece4P + b
|
1454
|
-
|
1455
|
-
exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
|
1456
|
-
exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
|
1457
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1612
|
+
return Q
|
1458
1613
|
|
1459
1614
|
|
1460
|
-
@
|
1461
|
-
def psgd_balance_Q(
|
1462
|
-
norms =
|
1463
|
-
geometric_mean = norms
|
1464
|
-
|
1465
|
-
|
1615
|
+
@decorator_knowngood
def psgd_balance_Q(Q):
    """Rescale every factor of ``Q`` in place so they share one max-abs magnitude.

    The balancing is done in log space: each factor is multiplied by
    ``exp(mean(log_norms) - log_norm)``, where ``log_norm`` is the log of its
    infinity norm. Nothing is returned; the tensors in ``Q`` are mutated.
    """
    log_norms = [promote(factor.norm(float("inf"))).log() for factor in Q]
    mean_log_norm = sum(log_norms) / len(Q)
    for factor, log_norm in zip(Q, log_norms):
        factor *= (mean_log_norm - log_norm).exp()
|
1466
1621
|
|
1467
1622
|
|
1468
|
-
@
|
1469
|
-
def
|
1470
|
-
u_norm =
|
1471
|
-
v_norm =
|
1472
|
-
scale = (u_norm / v_norm) ** 0.
|
1473
|
-
|
1474
|
-
|
1623
|
+
@decorator_knowngood
def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
    """Balance the magnitudes of ``U`` and ``V`` in place, then flatten ``U``, ``V`` and ``d``.

    ``U`` is divided and ``V`` multiplied by the fourth root of the ratio of their
    summed squares. A non-finite or tiny (<= 1e-6) ratio falls back to 1, i.e. no rescaling.
    Returns the result of ``multi_flatten`` for ``(U, 1)``, ``(V, 1)`` and ``(d, 0)``.
    """
    u_sq = sum(u.square().sum().double() for u in U)
    v_sq = sum(v.square().sum().double() for v in V)
    # fourth root of the squared-norm ratio: the correction is split across two factors
    ratio = (u_sq / v_sq) ** 0.25
    usable = torch.logical_and(torch.isfinite(ratio), ratio > 1e-6)
    ratio = torch.where(usable, ratio, 1)
    stochastic_multiply_(U, [1 / ratio] * len(U))
    stochastic_multiply_(V, [ratio] * len(V))
    return multi_flatten((U, 1), (V, 1), (d, 0))
|
1475
1632
|
|
1476
1633
|
|
1477
1634
|
@decorator
def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
    """Return ``x`` plus the low-rank product ``einsum("br,gr,g->b", U, V, x)``.

    The contraction runs in the dtype chosen by ``min_dtype`` over the three
    inputs and is cast back to ``x.dtype`` before being added to ``x``.
    """
    work_dtype = min_dtype([U, V, x])
    correction = compiled_einsum("br,gr,g->b", U.to(work_dtype), V.to(work_dtype), x.to(work_dtype))
    return x + correction.to(x.dtype)
|
1638
|
+
|
1639
|
+
|
1640
|
+
@decorator_knowngood
def _compilable_d_step(
    d: Tensor,
    d_orig: List[Tensor],
    invQtv: Tensor,
    vector: Tensor,
    inverse_precond_vector: Tensor,
    hessian_vector: Tensor,
    precond_hessian_vector: Tensor,
    eps: Tensor,
    step: Tensor,
    delayed: bool,
):
    """Apply one normalized update to the diagonal term ``d`` of the low-rank preconditioner.

    The gradient ``nablaD`` is added to ``d_orig`` through ``apply_flat_add`` with a
    negative, normalized step. Unless ``delayed`` is set, the flattened ``d`` receives the
    same update via ``copy_stochastic_``.
    """
    precond_hessian_vector = promote(precond_hessian_vector)
    hessian_vector = promote(hessian_vector)
    vector = promote(vector)
    inverse_precond_vector = promote(inverse_precond_vector)
    invQtv = promote(invQtv)
    inverse_precond_vector = invQtv - inverse_precond_vector
    d_hp = promote(d)

    nablaD = d_hp.square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector

    # Step-size normalization:
    # 1) Sketching
    #    1.1) multiply, square, etc. in high precision (avoids numerical errors, no extra cost)
    #    1.2) reduced-precision selection of the largest element (halves memory traffic)
    # 2) Computation
    #    2.1) select the relevant index
    #    2.2) redo 1.1 in double precision for the selected scalar values
    #    2.3) return the high-precision normalized step size
    # Overall this should reduce cost versus the baseline (less memory traffic) while
    # improving precision.
    lhs_scaled = d_hp * precond_hessian_vector
    lhs_plain = vector
    rhs_scaled = inverse_precond_vector / d_hp
    rhs_plain = hessian_vector

    sketch = (lhs_scaled.square() + lhs_plain.square()) * (rhs_scaled.square() + rhs_plain.square())
    idx = sketch.bfloat16().flatten().argmax()
    a = lhs_scaled.index_select(0, idx).double().square() + lhs_plain.index_select(0, idx).double().square()
    b = rhs_scaled.index_select(0, idx).double().square() + rhs_plain.index_select(0, idx).double().square()
    divisor = (a * b).sqrt().clamp(min=eps)
    step = -step / divisor

    # fused update(s)
    apply_flat_add(d_orig, nablaD, step)
    if delayed:
        return
    # promote(d) is deliberately re-evaluated here, after d_orig has been updated
    copy_stochastic_(d, promote(d) - nablaD * step)
|
1481
1689
|
|
1482
1690
|
|
1483
1691
|
def update_lra_precond_(
|
@@ -1489,13 +1697,14 @@ def update_lra_precond_(
|
|
1489
1697
|
eps: float,
|
1490
1698
|
step: float,
|
1491
1699
|
delayed: bool,
|
1700
|
+
precond_u: bool,
|
1492
1701
|
):
|
1493
1702
|
"""
|
1494
1703
|
Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
|
1495
1704
|
"""
|
1496
1705
|
U_orig, V_orig, d_orig = U, V, d
|
1497
1706
|
|
1498
|
-
U, V, d =
|
1707
|
+
U, V, d = _lra_flatten_and_balance(U, V, d)
|
1499
1708
|
|
1500
1709
|
dtype = min_dtype([U, V, vector, hessian_vector])
|
1501
1710
|
U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
|
@@ -1503,10 +1712,10 @@ def update_lra_precond_(
|
|
1503
1712
|
eps = scalar_guard(eps, vector)
|
1504
1713
|
|
1505
1714
|
Qh = low_rank_mm(U, V, d * hessian_vector)
|
1506
|
-
Ph =
|
1715
|
+
Ph = low_rank_mm(V, U, Qh)
|
1507
1716
|
rank = U.size(1)
|
1508
1717
|
|
1509
|
-
VtU =
|
1718
|
+
VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
|
1510
1719
|
I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
|
1511
1720
|
IpVtU = I + VtU
|
1512
1721
|
invQtv = vector / d
|
@@ -1524,47 +1733,39 @@ def update_lra_precond_(
|
|
1524
1733
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1525
1734
|
|
1526
1735
|
invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
|
1527
|
-
invPv =
|
1528
|
-
invPv = invPv / d
|
1736
|
+
invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
|
1529
1737
|
|
1530
|
-
|
1531
|
-
|
1532
|
-
divisor = divisor.add(eps).sqrt().max()
|
1533
|
-
d_step = step / divisor
|
1534
|
-
|
1535
|
-
apply_flat_add(d_orig, d * nablaD, -d_step)
|
1738
|
+
eps, step = scalar_guard(eps, step, vector)
|
1739
|
+
_compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
|
1536
1740
|
|
1537
1741
|
a, b = Qh, invQtv
|
1538
1742
|
|
1539
|
-
precond_u = random.random() < 0.5 # update either U or V, not both at the same time
|
1540
1743
|
precond = V if precond_u else U
|
1541
|
-
atV =
|
1542
|
-
btV =
|
1543
|
-
atVVt =
|
1544
|
-
btVVt =
|
1545
|
-
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()
|
1744
|
+
atV = compiled_einsum("b,br->r", a, precond) # o == one
|
1745
|
+
btV = compiled_einsum("b,br->r", b, precond)
|
1746
|
+
atVVt = compiled_einsum("r,br->b", atV, precond)
|
1747
|
+
btVVt = compiled_einsum("r,br->b", btV, precond)
|
1748
|
+
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
|
1546
1749
|
if precond_u:
|
1547
|
-
a =
|
1548
|
-
b =
|
1750
|
+
a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
|
1751
|
+
b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
|
1549
1752
|
else:
|
1550
|
-
a = a +
|
1551
|
-
b = b +
|
1552
|
-
a =
|
1553
|
-
b =
|
1753
|
+
a = a + compiled_einsum("br,r->b", V, atV)
|
1754
|
+
b = b + compiled_einsum("br,r->b", V, btV)
|
1755
|
+
a = compiled_einsum("b,r->br", a, atV)
|
1756
|
+
b = compiled_einsum("b,r->br", b, btV)
|
1554
1757
|
apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
|
1555
|
-
|
1556
1758
|
if not delayed:
|
1557
|
-
stochastic_add_([d], [d * nablaD], -d_step)
|
1558
1759
|
stochastic_add_([U if precond_u else V], [b - a], precond_step)
|
1559
1760
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1560
1761
|
|
1561
1762
|
|
1562
|
-
def lra_precond(U, V, d, g):
|
1763
|
+
def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
    """
    Precondition ``g`` with the low-rank preconditioner described by ``U``, ``V`` and ``d``.

    Computes ``d * low_rank_mm(V, U, low_rank_mm(U, V, d * g))``.
    As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
    """
    inner = low_rank_mm(U, V, d * g)
    outer = low_rank_mm(V, U, inner)
    return d * outer
|
1568
1769
|
|
1569
1770
|
|
1570
1771
|
@decorator_knowngood
|
@@ -1575,16 +1776,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
|
|
1575
1776
|
|
1576
1777
|
|
1577
1778
|
@decorator_knowngood
def _compilable_lra_update_(
    params: List[Tensor],
    update: List[Tensor],
    U: Tensor,
    V: Tensor,
    d: Tensor,
    lr: Tensor,
    decay: Tensor,
    caution: bool,
    grads: List[Tensor],
):
    """Precondition the flattened ``update`` with ``lra_precond`` and apply it per parameter.

    The preconditioned vector is cut into consecutive chunks of ``p.numel()`` elements,
    each reshaped like its parameter and handed to ``update_param_``.
    """
    flat = lra_precond(U, V, d, flatten(update)).flatten()
    offset = 0
    for p, g in zip(params, grads):
        count = p.numel()
        chunk = flat[offset : offset + count].view_as(p)
        update_param_(p, chunk, lr, decay, caution, g)
        offset += count
|
1586
1797
|
|
1587
1798
|
|
1799
|
+
def apply_lra_update(
    params: List[Tensor],
    update: Tensor,
    U: Tensor,
    V: Tensor,
    d: Tensor,
    lr: float,
    decay: float,
    caution: bool,
    grads: List[Tensor],
):
    """Eager entry point for ``_compilable_lra_update_``.

    Normalizes ``params``/``grads`` through ``list_guard`` and wraps ``lr``/``decay``
    with ``scalar_guard`` (using ``params[0]`` as reference) before delegating.
    """
    guarded_params, guarded_grads = list_guard(params, grads)
    lr_t, decay_t = scalar_guard(lr, decay, guarded_params[0])
    _compilable_lra_update_(guarded_params, update, U, V, d, lr_t, decay_t, caution, guarded_grads)
|
1813
|
+
|
1814
|
+
|
1588
1815
|
@decorator_knowngood
|
1589
1816
|
def apply_flat_update(params: List[Tensor], update: Tensor):
|
1590
1817
|
start = 0
|
@@ -1595,6 +1822,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
|
|
1595
1822
|
start += size
|
1596
1823
|
|
1597
1824
|
|
1825
|
+
@decorator_knowngood
def zero_(x: List[Tensor]):
    """Zero every tensor in ``x`` in place."""
    for tensor in x:
        tensor.zero_()
|
1829
|
+
|
1830
|
+
|
1598
1831
|
@decorator_knowngood
|
1599
1832
|
def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
|
1600
1833
|
start = 0
|
@@ -1620,7 +1853,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
|
|
1620
1853
|
@decorator_knowngood
def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
    """Concatenate the non-empty tensors of ``x`` along dim 0.

    Each tensor is reshaped to ``(-1, *trailing)``, where ``trailing`` is the last
    ``remaining`` dims of ``x[0]`` (nothing when ``remaining`` is 0). Tensors with
    zero elements are left out of the concatenation.
    """
    trailing = x[0].shape[-remaining:] if remaining else []
    pieces = []
    for tensor in x:
        if tensor.numel():
            pieces.append(tensor.reshape(-1, *trailing))
    return torch.cat(pieces, 0)
|
1857
|
+
|
1858
|
+
|
1859
|
+
@decorator_knowngood
def multi_flatten(*xs: Tuple[List[Tensor], int]):
    """Run ``flatten(tensors, remaining)`` for every ``(tensors, remaining)`` pair; return the results as a list."""
    flattened = []
    for tensors, remaining in xs:
        flattened.append(flatten(tensors, remaining))
    return flattened
|
1624
1862
|
|
1625
1863
|
|
1626
1864
|
@decorator_knowngood
|
@@ -1634,68 +1872,277 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
|
|
1634
1872
|
return flatten(vs), flatten(gs)
|
1635
1873
|
|
1636
1874
|
|
1637
|
-
@decorator_knowngood
|
1638
1875
|
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
    """Evaluate ``expr`` via ``compiled_einsum`` with all operands cast to ``min_dtype(args)``.

    The result is cast back to the dtype of the last operand.
    """
    work_dtype = min_dtype(args)
    operands = [operand.to(work_dtype) for operand in args]
    result = compiled_einsum(expr, *operands)
    return result.to(args[-1].dtype)
|
1641
1878
|
|
1642
1879
|
|
1643
|
-
|
1644
|
-
|
1645
|
-
|
1646
|
-
|
1647
|
-
|
1648
|
-
A = casted_einsum(exprA, *Q, G)
|
1649
|
-
for i, q in enumerate(Q):
|
1880
|
+
@decorator_knowngood
def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
    """Divide ``conjB`` by all factors of rank <= 1 and collect the remaining factors.

    A factor with ``dim() <= 1`` is broadcast along its own axis index of ``conjB``
    (or divided directly when ``conjB`` is 0-dim). Factors of higher rank are returned
    as ``(axis, promoted_factor)`` pairs for the caller to handle.

    Returns:
        ``(matrix_factors, conjB)`` with ``conjB`` promoted and already divided.
    """
    matrix_factors = []
    conjB = promote(conjB)
    for axis, raw in enumerate(Qs):
        factor = promote(raw)
        if factor.dim() > 1:
            matrix_factors.append((axis, factor))
            continue
        if conjB.ndim == 0:
            conjB = conjB / factor
            continue
        broadcast_shape = [1] * conjB.ndim
        broadcast_shape[axis] = -1
        conjB = conjB / factor.view(broadcast_shape)
    return matrix_factors, conjB
|
1896
|
+
|
1897
|
+
|
1898
|
+
@decorator_knowngood
def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
    """Undo the previous axis swap on ``solved`` and move axis ``new_dim`` to the last position.

    ``solved`` is viewed as ``transposed_shape``, axis ``last_dim`` is swapped back with the
    last axis, the tensor is reshaped to ``original_shape`` and finally ``new_dim`` is swapped
    into the last position.

    Returns:
        A contiguous copy of the result together with its shape.
    """
    restored = solved.reshape(transposed_shape).transpose(-1, last_dim)
    swapped = restored.reshape(original_shape).transpose(-1, new_dim)
    return swapped.contiguous(), swapped.shape
|
1905
|
+
|
1906
|
+
|
1907
|
+
def ndim_tuple(Q: list[Tensor]) -> tuple:
    """Return the ``ndim`` of each element of ``Q`` as a (hashable) tuple."""
    ranks = [factor.ndim for factor in Q]
    return tuple(ranks)
|
1909
|
+
|
1910
|
+
|
1911
|
+
def psgd_calc_A_and_conjB(G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
    """Compute ``A`` (``Q`` applied to ``G``) and ``conjB`` (``conjB`` with ``Q`` solved out).

    ``A`` uses the same einsum expression as the cached preconditioned-gradient path.
    For ``conjB``, factors of rank <= 1 are divided out by ``_psgd_calc_scalars_``; each
    remaining factor is removed with an upper-triangular right-hand solve after its axis
    has been moved to the last position.
    """
    expr_a = cached_precond_grad_expr(ndim_tuple(Q), G.ndim)  # calcA expr and cached precond expr are the same
    A = casted_einsum(expr_a, *Q, G)
    tri_solve = torch.compiler.disable(torch.linalg.solve_triangular)
    original_shape = conjB.shape
    current_shape = original_shape
    last_axis = -1
    matrix_factors, conjB = _psgd_calc_scalars_(Q, conjB)
    for axis, tri_q in matrix_factors:
        conjB, current_shape = _reshape_conjB(conjB, current_shape, original_shape, last_axis, axis)
        last_axis = axis
        conjB = tri_solve(tri_q, conjB, upper=True, left=False)
    # restore the original axis order
    conjB, _ = _reshape_conjB(conjB, current_shape, original_shape, last_axis, -1)
    return A, conjB
|
1659
1924
|
|
1660
1925
|
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1666
|
-
|
1667
|
-
|
1668
|
-
x
|
1669
|
-
|
1670
|
-
|
1671
|
-
|
1926
|
+
@decorator_knowngood
def _random_projection(x: Tensor, scale: Optional[Tensor]):
    """Select the ``k`` columns of ``x`` with the largest squared norm, divided by ``scale``.

    ``k`` is 2 raised to the ceiling of ``log2(log2(min(x.shape)))``. When ``scale`` is
    ``None`` it defaults to the infinity norm of ``x`` clamped to at least 1e-8.

    Returns:
        ``(selected_columns / scale, scale)``.
    """
    if scale is None:
        scale = x.norm(float("inf")).clamp(min=1e-8)
    # next-largest-power-of-2 of log2-of-size
    k = 2 ** math.ceil(math.log2(math.log2(min(x.shape))))
    column_energy = x.square().sum(0)
    top_columns = torch.topk(column_energy, k, largest=True).indices
    selected = x.index_select(1, top_columns).contiguous()
    return selected / scale, scale
|
1934
|
+
|
1935
|
+
|
1936
|
+
def max_singular_value_exact(A, use_lobpcg: bool = False):
    """Return the largest singular value of the matrix ``A`` as a 0-dim tensor.

    Args:
        A: input matrix.
        use_lobpcg: when true, take the square root of the top eigenvalue of ``A @ A.T``
            found by ``torch.lobpcg``; otherwise take the maximum of the singular values
            returned by ``torch.linalg.svd``.

    Returns:
        The largest singular value, or a zero scalar (same device/dtype as ``A``) when the
        decomposition raises ``torch.linalg.LinAlgError``.
    """
    try:
        if use_lobpcg:
            gram = A @ A.T
            eigval, _ = torch.compiler.disable(torch.lobpcg)(gram, k=1, largest=True)
            return eigval[0].sqrt()
        # `driver=` is a cuSOLVER-only option; passing it for CPU inputs raises a
        # RuntimeError, so only request gesvdj when the input lives on a CUDA device.
        driver = "gesvdj" if A.is_cuda else None
        return torch.linalg.svd(A, driver=driver)[1].max()  # == linalg.matrix_norm(A, ord=2)
    except torch.linalg.LinAlgError:
        return torch.zeros((), device=A.device, dtype=A.dtype)
|
1946
|
+
|
1947
|
+
|
1948
|
+
@decorator_knowngood
def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
    """
    Estimate the largest singular value of ``A`` by power iteration.

    The start vector is the row of ``A`` with the largest norm; both ``A`` and the vector are
    divided by that norm. After ``iterations`` normalized applications of ``A.T @ A`` the
    square root of the Rayleigh quotient, rescaled by the row norm, is returned.
    ``max_abs`` is accepted but not used.
    """
    row_norm, row_idx = A.norm(dim=1).max(dim=0)
    vec = A.index_select(0, row_idx).flatten().contiguous()
    A = A / row_norm
    vec = vec / row_norm
    for _ in range(iterations):
        # A.T @ (A @ vec), written as two mat-vec products so the full Gram matrix is never formed
        vec = A.T.mv(A.mv(vec))
        vec = vec / vec.norm()
    rayleigh = vec @ A.T.mv(A.mv(vec))
    return rayleigh.sqrt() * row_norm
|
1961
|
+
|
1962
|
+
|
1963
|
+
@decorator_knowngood
def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
    """
    Sketch-based estimate of the largest singular value of ``A``.

    A column subset of ``A`` (from ``_random_projection``) is orthogonalized, ``A.T`` is
    projected onto it, the projection is orthogonalized again, and the exact top singular
    value of the small sketch is rescaled by ``max_abs``.
    Adapted from @evanatyourservice
    """
    sketch, max_abs = _random_projection(A, max_abs)
    basis = inplace_orthogonal_(sketch, precise_zeroth_power_mode)
    basis = basis / max_abs
    projected = A.T @ basis
    projected_basis = inplace_orthogonal_(projected, precise_zeroth_power_mode)
    sketch_norm = max_singular_value_exact(projected.T @ projected_basis)
    return sketch_norm * max_abs
|
1975
|
+
|
1976
|
+
|
1977
|
+
@decorator_knowngood
def max_singular_value(
    A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
) -> Tensor:
    """Largest singular value of ``A``, choosing the method by size and flags.

    * smallest dimension <= ``max_svd``: exact SVD-based value;
    * ``use_cholesky`` set or ``power_iter`` negative: sketch-based estimate;
    * otherwise: power iteration with ``power_iter`` iterations.
    """
    if min(A.shape) <= max_svd:
        # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
        result = max_singular_value_exact(A)
    elif use_cholesky or power_iter < 0:
        result = max_singular_value_cholesky(A, max_abs)
    else:
        result = max_singular_value_power_iter(A, None, iterations=power_iter)
    return result
|
1986
|
+
|
1987
|
+
|
1988
|
+
@decorator_knowngood
def _psgd_default_preconditioner_grad(
    terms: List[Tuple[Tensor, Tensor]],
    Q: List[Tensor],
) -> List[Tensor]:
    """Turn each ``(x, y)`` term pair into a preconditioner gradient for its factor.

    The promoted difference ``x - y`` is multiplied elementwise by ``q`` when ``q`` has
    fewer than two dims; otherwise it is left-multiplied by ``q`` and reduced to its
    upper triangle.
    """
    grads = []
    for q, (first, second) in zip(Q, terms):
        diff = promote(first) - promote(second)
        if q.ndim >= 2:
            grads.append((q @ diff).triu())
        else:
            grads.append(q * diff)
    return grads
|
2004
|
+
|
2005
|
+
|
2006
|
+
@decorator_knowngood
def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
    """Balance the factors of ``Q`` in place and return them as full tensors.

    When the first entry is a tuple, ``Q`` is treated as ``(shape, values)`` pairs: the
    values are balanced and the list is expanded with ``line_to_triu``. Otherwise ``Q`` is
    balanced and returned as is.
    """
    if not isinstance(Q[0], tuple):
        psgd_balance_Q(Q)
        return Q
    psgd_balance_Q([entry[1] for entry in Q])
    return line_to_triu(Q, symmetric_output)
|
2013
|
+
|
2014
|
+
|
2015
|
+
@functools.lru_cache(maxsize=None)
def calcG_expr(q_dim, g_dim):
    """Build one einsum expression per preconditioner factor (results are cached).

    Every expression contracts two operands indexed by the first ``g_dim`` letters of
    ``einsum_base``. For a factor of rank 2 the second operand's axis is renamed to ``Z``
    and the output is ``<axis>Z``; otherwise both operands share the axis letter and the
    output is that single letter.
    """
    base = einsum_base[:g_dim]
    exprs = []
    for axis, rank in enumerate(q_dim):
        second = list(base)
        letter = base[axis]
        if rank == 2:
            second[axis] = "Z"
            target = f"{letter}Z"
        else:
            target = letter
        exprs.append(f"{base},{''.join(second)}->{target}")
    return exprs
|
1672
2028
|
|
1673
2029
|
|
1674
2030
|
@decorator
def psgd_update_precond(
    G: Tensor,
    precond_lr: float,
    oq: "TriuOrLine",
    store_triu_as_line: bool,
    velocity: Optional[List[Tensor]],
    beta2: float,
    ortho_method: Optional[str],
    V: Tensor,
    running_lower_bound: List[Tensor],
    lower_bount_beta: float,
    power_iter: int,
) -> None:
    """Update Kronecker product preconditioner Q with pair (V, G).

    ``oq`` is balanced and expanded, ``A``/``conjB`` are computed from ``(G, Q, V)``, the
    per-factor statistics are contracted with the ``calcG_expr`` expressions, and the
    resulting gradients are written back into ``oq`` by ``_psgd_precond_update_``.
    ``velocity`` and ``ortho_method`` are not read here.
    """
    Q = _balance_to_triu(oq)
    factor_exprs = calcG_expr(ndim_tuple(Q), G.ndim)
    precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)

    A, conjB = psgd_calc_A_and_conjB(G, Q, V)
    terms = []
    for factor_expr in factor_exprs:
        terms.append((compiled_einsum(factor_expr, A, A), compiled_einsum(factor_expr, conjB, conjB)))
    del A, conjB, V  # drop references before the update step
    updates = _psgd_default_preconditioner_grad(terms, Q)
    _psgd_precond_update_(
        updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
    )
    return None
|
2057
|
+
|
2058
|
+
|
2059
|
+
@decorator_knowngood
def _psgd_precond_update_(
    matmuled: List[Optional[Tensor]],
    Q: "TriuOrLine",
    running_lower_bound: List[Tensor],
    lower_bount_beta: Tensor,
    precond_lr: Tensor,
    store_triu_as_line: bool,
    power_iter: int,
):
    """Apply normalized gradient steps to every preconditioner factor in place.

    For each factor the update is divided by a norm bound (infinity norm for updates with
    fewer than two dims, ``max_singular_value`` otherwise). The bound is never allowed to
    drop below the running value moved toward it by ``1 - lower_bount_beta``, and the
    running value is overwritten with the bound that was used.
    """
    for step_dir, stored, bound_state in zip(matmuled, Q, running_lower_bound):
        if isinstance(stored, tuple):
            stored = stored[1]  # (shape, values) storage: update the values tensor

        current = promote(stored)
        if step_dir.ndim < 2:
            bound = step_dir.norm(float("inf"))
        else:
            bound = max_singular_value(step_dir, None, power_iter=power_iter)
            step_dir = promote(step_dir)
            if store_triu_as_line:
                step_dir = triu_to_line([step_dir])[0][1]

        bound = promote(bound)
        previous = promote(bound_state)
        bound = bound.maximum(previous + (bound - previous) * (1 - lower_bount_beta))
        copy_stochastic_(bound_state, bound)
        copy_stochastic_(stored, current - step_dir / bound * precond_lr)
|
2086
|
+
|
2087
|
+
|
2088
|
+
@decorator_knowngood
def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
    """
    Compute the update direction for each factor of the quadratic (inverse-free) form.

    Notation: I = identity, T = gg (target statistic), Q = preconditioner factor,
    scale = scalar scale (factor size divided by ``numel``).
    ---
    U = T * scale - I
    F = I - U  # = 2I - T * scale
    O = Q - F @ Q @ F  (elementwise products when the statistic has fewer than two dims)

    For matrices, ``O + O.T`` is returned to keep the update symmetric.
    """
    directions = []
    for gg, q in zip(GG, Q):
        if gg.ndim >= 2:
            scale = gg.size(0) / numel
            reflector = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
            delta = q - reflector @ q @ reflector
            directions.append(delta + delta.T)  # make matrix symmetric
            continue
        scale = max(1, gg.numel()) / numel
        target = promote(gg)
        update = target * scale - 1
        directions.append(q - (1 - update) * q * (1 - update))
    return directions
|
2113
|
+
|
2114
|
+
|
2115
|
+
@decorator
def inverse_free_psgd_update_precond(
    G: Tensor,
    precond_lr: float,
    oq: List[Tensor],
    store_triu_as_line: bool,
    velocity: Optional[List[Tensor]],
    beta2: float,
    ortho_method: Optional[str],
    V: None,
    running_lower_bound: List[Tensor],
    lower_bount_beta: float,
    power_iter: int,
) -> Tensor:
    """Update Kronecker product preconditioner Q from G alone and return the preconditioned G.

    The signature mirrors ``psgd_update_precond``; ``V``, ``ortho_method`` and ``velocity``
    must be ``None`` for this variant. ``oq`` is updated in place.
    """
    assert V is None
    assert ortho_method is None
    assert velocity is None
    del V, ortho_method, velocity

    Q = _balance_to_triu(oq, True)
    precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
    factor_exprs = calcG_expr(ndim_tuple(Q), G.ndim)

    G = psgd_precond_grad(G, Q)
    stats = []
    for factor_expr in factor_exprs:
        stats.append(compiled_einsum(factor_expr, G, G))
    directions = _psgd_quad_preconditioner_grad(stats, Q, G.numel())
    _psgd_precond_update_(
        directions, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
    )
    return G
|
1699
2146
|
|
1700
2147
|
|
1701
2148
|
@decorator_knowngood
|
@@ -1732,6 +2179,34 @@ def rmsnorm_clip_(x, clip_at: float = 1.0):
|
|
1732
2179
|
return _compilable_rmsnorm_clip_(x, clip_at)
|
1733
2180
|
|
1734
2181
|
|
2182
|
+
@decorator_knowngood
def _compilable_global_rmsnorm_clip_(x, clip_at):
    """Divide all tensors in ``x`` by their joint RMS, with the divisor clamped to at least ``clip_at``.

    The RMS is taken over every element of every (promoted) tensor. Returns new tensors
    from ``torch._foreach_div``.
    """
    tensors = [promote(t) for t in x]
    sq_total = sum([t.square().sum() for t in tensors])
    count = sum([t.numel() for t in tensors])
    norm = sq_total / count
    norm = norm**0.5
    norm = norm.clamp(min=clip_at)
    return torch._foreach_div(tensors, norm)
|
2189
|
+
|
2190
|
+
|
2191
|
+
@decorator_knowngood
def _compilable_global_l2norm_clip_(x, clip_at):
    """Divide all tensors in ``x`` by their joint L2 norm, with the divisor clamped to at least ``clip_at``.

    The norm is the square root of the summed squares over every (promoted) tensor.
    Returns new tensors from ``torch._foreach_div``.
    """
    tensors = [promote(t) for t in x]
    sq_total = sum([t.square().sum() for t in tensors])
    norm = sq_total**0.5
    norm = norm.clamp(min=clip_at)
    return torch._foreach_div(tensors, norm)
|
2198
|
+
|
2199
|
+
|
2200
|
+
def global_rmsnorm_clip(x, clip_at: float = 1.0):
    """Clip ``x`` by its joint RMS norm; ``x`` is normalized with ``list_guard`` first."""
    tensors = list_guard(x)
    clipped = _compilable_global_rmsnorm_clip_(tensors, clip_at)
    return clipped
|
2203
|
+
|
2204
|
+
|
2205
|
+
def global_l2norm_clip(x, clip_at: float = 1.0):
    """Clip ``x`` by its joint L2 norm.

    Args:
        x: tensor or list of tensors; normalized with ``list_guard``.
        clip_at: lower clamp for the divisor, so inputs whose joint L2 norm is below
            this value are divided by ``clip_at`` instead of their own norm.

    Returns:
        The tensors divided by ``max(l2_norm, clip_at)``.
    """
    x = list_guard(x)
    # Dispatch to the L2 variant; the RMS variant divides the squared sum by the
    # element count and is exposed separately as global_rmsnorm_clip.
    return _compilable_global_l2norm_clip_(x, clip_at)
|
2208
|
+
|
2209
|
+
|
1735
2210
|
def rmsnorm_normalize_(x, clip_at: float = 1e-6):
|
1736
2211
|
x = list_guard(x)
|
1737
2212
|
return _compilable_rmsnorm_clip_(x, clip_at)
|
@@ -1809,6 +2284,17 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
|
|
1809
2284
|
_compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
|
1810
2285
|
|
1811
2286
|
|
2287
|
+
@decorator_knowngood
def _compilable_weight_decay_to_init_(p, init, weight_decay):
    """Interpolate ``p`` toward the promoted ``init`` via ``_lerp`` with weight ``1 - weight_decay``."""
    target = promote(init)
    keep = 1 - weight_decay
    _lerp(p, target, keep)
|
2290
|
+
|
2291
|
+
|
2292
|
+
def weight_decay_to_init_(p, init, weight_decay):
    """Decay the parameters ``p`` toward their initial values ``init``.

    Args:
        p: parameter tensor(s), updated in place.
        init: initial value(s) matching ``p``.
        weight_decay: decay strength; wrapped with ``scalar_guard`` using ``p[0]`` as reference.
    """
    p, init = list_guard(p, init)
    weight_decay = scalar_guard(weight_decay, p[0])
    # Use the init-specific kernel: the EMA kernel takes an extra `ema_decay` argument,
    # so calling it with these three arguments fails with a TypeError.
    _compilable_weight_decay_to_init_(p, init, weight_decay)
|
2296
|
+
|
2297
|
+
|
1812
2298
|
@decorator_knowngood
|
1813
2299
|
def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
|
1814
2300
|
ema32 = _lerp(ema, p, ema_decay)
|
@@ -1867,35 +2353,25 @@ def triu_to_line(Q_list: List[Tensor]):
|
|
1867
2353
|
if q.dim() < 2:
|
1868
2354
|
out.append((None, q))
|
1869
2355
|
else:
|
1870
|
-
out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
|
2356
|
+
out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
|
1871
2357
|
return out
|
1872
2358
|
|
1873
2359
|
|
1874
|
-
|
1875
|
-
|
1876
|
-
assert n * (n + 1) == 2 * numel
|
1877
|
-
return n, n
|
1878
|
-
|
1879
|
-
|
1880
|
-
@decorator
|
1881
|
-
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
|
2360
|
+
@decorator_knowngood
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
    """Expand ``(shape, values)`` pairs into dense tensors.

    Entries whose shape is ``None`` are passed through unchanged. Otherwise ``values`` is
    scattered into the upper triangle of a zero matrix of that shape; with
    ``symmetric_output`` the values are mirrored into the lower triangle as well.
    """
    expanded = []
    for shape, values in Q_list:
        if shape is None:
            expanded.append(values)
            continue
        rows, cols = torch.triu_indices(*shape, device=values.device)
        dense = torch.zeros(shape, device=values.device, dtype=values.dtype)
        dense[rows, cols] = values
        if symmetric_output:
            dense[cols, rows] = values
        expanded.append(dense)
    return expanded
|
1891
2373
|
|
1892
2374
|
|
1893
|
-
def update_triu_(q_state, materialised):
|
1894
|
-
for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
|
1895
|
-
assert shape0 == shape1
|
1896
|
-
copy_stochastic_(q, m)
|
1897
|
-
|
1898
|
-
|
1899
2375
|
_warned = set()
|
1900
2376
|
|
1901
2377
|
|
@@ -1918,52 +2394,118 @@ def psgd_should_update(
|
|
1918
2394
|
return int(group[name]) > int(cumulative_prob)
|
1919
2395
|
|
1920
2396
|
|
2397
|
+
@functools.lru_cache(maxsize=None)
def cached_precond_grad_expr(Q_dim, grad_dim):
    """Build the einsum expression that applies one factor per axis to a gradient (cached).

    A factor of rank 2 on axis letter ``c`` contributes the operand ``Cc`` (upper-case then
    lower-case); any other factor contributes ``c``. The output uses the upper-case letter
    for every axis whose upper-case form appears among the factor operands.
    """
    factor_terms = []
    for letter, rank in zip(einsum_base, Q_dim):
        factor_terms.append(f"{letter.upper()}{letter}" if rank == 2 else letter)
    factors = ",".join(factor_terms)
    grad_expr = "".join(letter for letter, _ in zip(einsum_base, range(grad_dim)))
    out_letters = []
    for letter in grad_expr:
        upper = letter.upper()
        out_letters.append(upper if upper in factors else letter)
    out_expr = "".join(out_letters)
    return f"{factors},{grad_expr}->{out_expr}"
|
2404
|
+
|
2405
|
+
|
1921
2406
|
@decorator_knowngood
def precond_grad_cached_(
    ea: Tensor,
    cached_q: List[Tensor],
    caution: bool = False,
    grad: Optional[Tensor] = None,
    cast: bool = True,
):
    """Apply the cached preconditioner factors ``cached_q`` to ``ea``.

    Args:
        ea: tensor to precondition.
        cached_q: one cached factor per axis of ``ea``.
        caution: when true, ``ea`` is first passed through ``_compilable_cautioning(grad, ea)``.
        grad: gradient used for cautioning; only needed when ``caution`` is true.
        cast: when true the result is cast back to ``ea.dtype``.

    Returns:
        The preconditioned tensor.
    """
    if caution:
        ea = _compilable_cautioning(grad, ea)
    md = min_dtype(list(cached_q) + [ea])
    args = [q.to(md) for q in cached_q]
    args = args + [ea.to(md)]
    # Take the rank from `ea`, the tensor actually contracted: `grad` defaults to None,
    # so reading `grad.ndim` crashed whenever the caller did not supply a gradient.
    expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
    new = compiled_einsum(expr, *args)
    if cast:
        return new.to(ea.dtype)
    return new
|
1934
2424
|
|
1935
2425
|
|
2426
|
+
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
|
2427
|
+
|
2428
|
+
|
1936
2429
|
@decorator_knowngood
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
    """Precondition ``ea`` with the cached factors and apply the result to ``param``.

    Cautioning (if requested) is handled inside ``precond_grad_cached_``; the parameter
    update itself is therefore called with ``caution=False``.
    """
    preconditioned = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
    update_param_(param, preconditioned, lr, decay, caution=False)
|
1940
2433
|
|
1941
2434
|
|
1942
|
-
def fused_precond_grad_cached_(
|
2435
|
+
def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
    """Eager wrapper: wrap ``lr`` with ``scalar_guard`` (reference ``param[0]``) and run the fused kernel."""
    guarded_lr = scalar_guard(lr, param[0])
    _compilable_fused_precond_grad_cached_(ea, param, guarded_lr, grad, decay, caution, cached_q)
|
2438
|
+
|
2439
|
+
|
2440
|
+
@functools.lru_cache(maxsize=None)
def precond_grad_expr(Q_dim, grad_dim):
    """Build the einsum expression that applies every factor twice to a gradient (cached).

    Each factor appears as two operands. For a rank-2 factor on axis letter ``c`` these are
    ``xC`` and ``xc``, where ``x`` is the letter 13 positions further along ``einsum_base``;
    any other factor contributes ``c,c``. The output upper-cases every axis whose upper-case
    letter occurs among the factor operands.
    """
    factor_terms = []
    for letter, shared, rank in zip(einsum_base, einsum_base[13:], Q_dim):
        if rank == 2:
            factor_terms.append(f"{shared}{letter.upper()},{shared}{letter}")
        else:
            factor_terms.append(f"{letter},{letter}")
    factors = ",".join(factor_terms)
    grad_expr = "".join(letter for letter, _ in zip(einsum_base, range(grad_dim)))
    out_letters = []
    for letter in grad_expr:
        upper = letter.upper()
        out_letters.append(upper if upper in factors else letter)
    out_expr = "".join(out_letters)
    return f"{factors},{grad_expr}->{out_expr}"
|
1945
2449
|
|
1946
2450
|
|
1947
2451
|
@decorator_knowngood
def psgd_precond_grad(
    ea: Tensor,
    preconds: TriuOrLine,
    caution: bool = False,
    grad: Optional[Tensor] = None,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Precondition ``ea`` with the factors ``preconds`` (each factor is applied twice).

    With ``caution`` set, ``ea`` is first passed through ``_compilable_cautioning(grad, ea)``.
    With ``store_triu_as_line`` set, ``preconds`` is expanded by ``line_to_triu`` first.
    The contraction runs in the ``min_dtype`` of all operands; the result is cast back to
    ``ea.dtype``.
    """
    if caution:
        ea = _compilable_cautioning(grad, ea)
    if store_triu_as_line:
        preconds = line_to_triu(preconds, symmetric_output)
    work_dtype = min_dtype(list(preconds) + [ea])
    factors = [q.to(work_dtype) for q in preconds]
    expr = precond_grad_expr(ndim_tuple(factors), ea.ndim)
    doubled = []
    for factor in factors:
        doubled.extend((factor, factor))  # the expression expects each factor as two operands
    result = compiled_einsum(expr, *doubled, ea.to(work_dtype))
    return result.to(ea.dtype)
|
1956
2469
|
|
1957
2470
|
|
1958
2471
|
@decorator_knowngood
def _compilable_fused_psgd_precond_grad(
    ea: Tensor,
    param,
    lr,
    grad,
    decay,
    caution,
    preconds: TriuOrLine,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Precondition ``ea`` with ``psgd_precond_grad`` and apply the result to ``param``.

    Cautioning (if requested) happens inside ``psgd_precond_grad``, so ``update_param_`` is
    called with ``caution=False``.
    """
    precond_kwargs = dict(
        caution=caution,
        grad=grad,
        store_triu_as_line=store_triu_as_line,
        symmetric_output=symmetric_output,
    )
    preconditioned = psgd_precond_grad(ea, preconds, **precond_kwargs)
    update_param_(param, preconditioned, lr, decay, caution=False, grad=grad)
|
1962
2492
|
|
1963
2493
|
|
1964
|
-
def fused_psgd_precond_grad(
|
2494
|
+
def fused_psgd_precond_grad(
    ea: Tensor,
    param,
    lr,
    grad,
    decay,
    caution,
    preconds: TriuOrLine,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Eager wrapper: wrap ``lr`` with ``scalar_guard`` (reference ``param[0]``) and run the fused kernel."""
    guarded_lr = scalar_guard(lr, param[0])
    _compilable_fused_psgd_precond_grad(
        ea, param, guarded_lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
    )
|
1967
2509
|
|
1968
2510
|
|
1969
2511
|
@decorator_knowngood
|
@@ -2015,7 +2557,15 @@ def caution(g, update):
|
|
2015
2557
|
return _compilable_cautioning(g, update)
|
2016
2558
|
|
2017
2559
|
|
2018
|
-
def
|
2560
|
+
def _inner_precond_update_prob_schedule(
|
2561
|
+
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2562
|
+
):
|
2563
|
+
return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
|
2564
|
+
|
2565
|
+
|
2566
|
+
def precond_update_prob_schedule(
|
2567
|
+
max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2568
|
+
):
|
2019
2569
|
"""Anneal preconditioner update probability during beginning of training.
|
2020
2570
|
|
2021
2571
|
PSGD benefits from more preconditioner updates at the beginning of training,
|
@@ -2026,11 +2576,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
|
|
2026
2576
|
`min_prob` by ~4000 steps. Default settings work very well for most models and
|
2027
2577
|
training regimes.
|
2028
2578
|
"""
|
2029
|
-
|
2030
|
-
|
2031
|
-
|
2032
|
-
|
2033
|
-
return _schedule
|
2579
|
+
return functools.partial(
|
2580
|
+
_inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
|
2581
|
+
)
|
2034
2582
|
|
2035
2583
|
|
2036
2584
|
def merge_group(group, *tensors):
|
@@ -2164,3 +2712,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
|
2164
2712
|
def disable_caution_scaling():
|
2165
2713
|
global _compilable_cautioning
|
2166
2714
|
_compilable_cautioning = _compilable_caution_no_scale
|
2715
|
+
|
2716
|
+
|
2717
|
+
@decorator_knowngood
def sam_step(parameters, ball_size, adaptive: bool = True):
    """Move each parameter along its gradient and return copies of the previous values.

    For every parameter a detached clone is stored, the promoted gradient (multiplied by the
    squared promoted parameter when ``adaptive``) is added to ``p.data`` through
    ``stochastic_add_`` with factor ``ball_size``, and ``p.grad`` is zeroed.

    Returns:
        List of the parameter values from before the step, in input order.
    """
    snapshots = []
    for p in parameters:
        snapshots.append(p.detach().clone())
        direction = promote(p.grad)
        if adaptive:
            direction = direction * promote(p).square()
        stochastic_add_(p.data, direction, ball_size)
        p.grad.zero_()
    return snapshots
|