heavyball 1.7.2__py3-none-any.whl → 2.0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -16
- heavyball/chainable.py +338 -190
- heavyball/helpers.py +804 -0
- heavyball/utils.py +770 -262
- heavyball-2.0.0.dev0.dist-info/METADATA +109 -0
- heavyball-2.0.0.dev0.dist-info/RECORD +9 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dev0.dist-info}/WHEEL +1 -1
- heavyball-1.7.2.dist-info/METADATA +0 -939
- heavyball-1.7.2.dist-info/RECORD +0 -8
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dev0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dev0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,28 +1,28 @@
|
|
1
|
+
import collections
|
1
2
|
import contextlib
|
2
3
|
import functools
|
3
4
|
import gc
|
4
5
|
import inspect
|
5
6
|
import math
|
7
|
+
import pickle
|
6
8
|
import random
|
7
9
|
import re
|
8
10
|
import string
|
9
11
|
import warnings
|
10
|
-
from typing import Callable, List, Optional, Tuple, Union
|
12
|
+
from typing import Callable, List, Literal, Optional, Tuple, Union
|
11
13
|
|
12
14
|
import numpy as np
|
13
15
|
import torch
|
14
16
|
from torch import Tensor
|
15
|
-
from torch._dynamo import config
|
16
17
|
from torch._dynamo.exc import TorchDynamoException
|
17
18
|
from torch.backends import cudnn, opt_einsum
|
18
19
|
from torch.utils._pytree import tree_map
|
19
20
|
|
20
|
-
config.cache_size_limit = 2**16
|
21
|
-
|
22
21
|
compile_mode = "max-autotune-no-cudagraphs"
|
23
22
|
dynamic = False
|
24
23
|
compile_mode_recommended_to_none = None
|
25
|
-
zeroth_power_mode = "
|
24
|
+
zeroth_power_mode = "newtonschulz"
|
25
|
+
precise_zeroth_power_mode = "qr" # or svd
|
26
26
|
tiny_bf16 = torch.finfo(torch.bfloat16).tiny
|
27
27
|
_cudnn_double_backward_pattern = re.compile(
|
28
28
|
r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
|
@@ -68,6 +68,16 @@ def decorator_knowngood(func: Callable, fullgraph: bool = True):
|
|
68
68
|
einsum_base = string.ascii_lowercase
|
69
69
|
|
70
70
|
|
71
|
+
@decorator_knowngood
|
72
|
+
def compiled_einsum(expr, *args):
|
73
|
+
"""
|
74
|
+
this is necessary to avoid the slowdown introduced by uncompiled einsum
|
75
|
+
uncompiled einsum is twice as slow if we add three 1-sized dimensions
|
76
|
+
for more, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
|
77
|
+
"""
|
78
|
+
return torch.einsum(expr, *args)
|
79
|
+
|
80
|
+
|
71
81
|
@decorator_knowngood
|
72
82
|
def _compilable_schedule_free_(
|
73
83
|
p: List[Tensor],
|
@@ -122,6 +132,47 @@ def schedule_free_(
|
|
122
132
|
return weight_sum
|
123
133
|
|
124
134
|
|
135
|
+
@decorator_knowngood
|
136
|
+
def _compilable_msam(
|
137
|
+
lr: Tensor,
|
138
|
+
beta1: Tensor,
|
139
|
+
param: List[Tensor],
|
140
|
+
z: List[Tensor],
|
141
|
+
update: List[Tensor],
|
142
|
+
grad: List[Tensor],
|
143
|
+
exp_avg: List[Tensor],
|
144
|
+
caution: bool,
|
145
|
+
decay: Tensor,
|
146
|
+
sam_step_size: Tensor,
|
147
|
+
):
|
148
|
+
exp_avg32 = _lerp(exp_avg, update, beta1)
|
149
|
+
for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
|
150
|
+
u_ = u_.view_as(z_)
|
151
|
+
z32_ = promote(z_)
|
152
|
+
if caution:
|
153
|
+
u_ = _compilable_cautioning(promote(g_), u_)
|
154
|
+
z32_ = z32_ * (1 - decay * lr) + u_ * -lr
|
155
|
+
copy_stochastic_(z_, z32_)
|
156
|
+
copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
|
157
|
+
|
158
|
+
|
159
|
+
def msam_(
|
160
|
+
lr: float,
|
161
|
+
beta1: float,
|
162
|
+
param: List[Tensor],
|
163
|
+
z: List[Tensor],
|
164
|
+
update: List[Tensor],
|
165
|
+
grad: List[Tensor],
|
166
|
+
exp_avg: List[Tensor],
|
167
|
+
caution: bool,
|
168
|
+
weight_decay: float,
|
169
|
+
sam_step_size: float,
|
170
|
+
):
|
171
|
+
param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
|
172
|
+
lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
|
173
|
+
_compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
|
174
|
+
|
175
|
+
|
125
176
|
def append_or_extend(base, new):
|
126
177
|
if isinstance(new, list):
|
127
178
|
base.extend(new)
|
@@ -161,7 +212,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
161
212
|
new_shape = [grad.shape[0], *new_shape[::-1]]
|
162
213
|
new_grad = grad.reshape(new_shape)
|
163
214
|
if not split:
|
164
|
-
return new_grad
|
215
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
165
216
|
|
166
217
|
grads = [new_grad]
|
167
218
|
for i, sh in reversed(list(enumerate(new_shape[:]))):
|
@@ -172,7 +223,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
172
223
|
continue
|
173
224
|
grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
|
174
225
|
if len(grads) == 1:
|
175
|
-
return new_grad
|
226
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
176
227
|
new_grads = []
|
177
228
|
for g in grads:
|
178
229
|
append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
|
@@ -279,16 +330,29 @@ def clean():
|
|
279
330
|
|
280
331
|
|
281
332
|
def _ignore_warning(msg):
|
282
|
-
warnings.filterwarnings("ignore", f".*{msg}.*")
|
333
|
+
warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
|
283
334
|
|
284
335
|
|
285
|
-
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
336
|
+
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
|
337
|
+
import opt_einsum as _opt_einsum
|
338
|
+
|
286
339
|
cudnn.benchmark = True
|
287
340
|
cudnn.deterministic = False
|
288
341
|
cudnn.benchmark_limit = benchmark_limit
|
289
342
|
torch.use_deterministic_algorithms(False)
|
290
343
|
torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
|
291
|
-
opt_einsum.set_flags(True
|
344
|
+
opt_einsum.set_flags(True)
|
345
|
+
if einsum_strategy == "heavyball":
|
346
|
+
opt_einsum.strategy = "auto-hq"
|
347
|
+
choices = _opt_einsum.paths._AUTO_HQ_CHOICES
|
348
|
+
for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
|
349
|
+
if isinstance(fn, int):
|
350
|
+
fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
|
351
|
+
for i in range(max(choices.keys()), max_val):
|
352
|
+
if i not in choices:
|
353
|
+
choices[i] = fn
|
354
|
+
else:
|
355
|
+
opt_einsum.strategy = einsum_strategy
|
292
356
|
|
293
357
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
294
358
|
_ignore_warning(
|
@@ -297,6 +361,9 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
|
297
361
|
_ignore_warning(
|
298
362
|
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
299
363
|
)
|
364
|
+
_ignore_warning(
|
365
|
+
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
|
366
|
+
)
|
300
367
|
|
301
368
|
|
302
369
|
@decorator
|
@@ -316,15 +383,6 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
316
383
|
return X.to(G.dtype)
|
317
384
|
|
318
385
|
|
319
|
-
def ortho(x):
|
320
|
-
if zeroth_power_mode == "qr":
|
321
|
-
return torch.linalg.qr(x).Q
|
322
|
-
if zeroth_power_mode == "svd":
|
323
|
-
u, _s, v = torch.linalg.svd(x)
|
324
|
-
return u @ v.T
|
325
|
-
raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
|
326
|
-
|
327
|
-
|
328
386
|
@decorator_knowngood
|
329
387
|
def _compilable_heavyball_momentum_(state, grad, beta):
|
330
388
|
s32, g32 = [list(map(promote, x)) for x in (state, grad)]
|
@@ -377,7 +435,7 @@ def _compilable_grafting(magnitude, direction):
|
|
377
435
|
|
378
436
|
|
379
437
|
@decorator_knowngood
|
380
|
-
def
|
438
|
+
def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
|
381
439
|
if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
|
382
440
|
y = zeropower_via_newtonschulz5(x, 5)
|
383
441
|
elif mode == "qr":
|
@@ -395,9 +453,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
|
|
395
453
|
y = _compilable_grafting(x, y)
|
396
454
|
else:
|
397
455
|
raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
|
456
|
+
if out is None:
|
457
|
+
return y
|
458
|
+
|
398
459
|
set_(out, y)
|
399
460
|
|
400
461
|
|
462
|
+
def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
|
463
|
+
return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
|
464
|
+
|
465
|
+
|
401
466
|
@decorator_knowngood
|
402
467
|
def _compilable_scatter_set(target, source, index):
|
403
468
|
target[:] = source.contiguous()[index].reshape_as(target)
|
@@ -413,6 +478,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
413
478
|
:param Q: List of current eigenbases (updated in-place to Q_new).
|
414
479
|
:param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
|
415
480
|
"""
|
481
|
+
if exp_avg.dim() == 0: # preconditioning doesn't make sense here
|
482
|
+
Q.clear()
|
483
|
+
return
|
484
|
+
|
416
485
|
if isinstance(Q, list) and not Q:
|
417
486
|
return
|
418
487
|
|
@@ -430,10 +499,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
430
499
|
q_old = promote(q.data)
|
431
500
|
|
432
501
|
tmp = m @ q_old
|
433
|
-
est_eig =
|
502
|
+
est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
|
434
503
|
sort_idx = torch.argsort(est_eig, descending=True)
|
435
504
|
|
436
|
-
tmp[:, sort_idx]
|
505
|
+
tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
|
437
506
|
new_qs.append(tmp)
|
438
507
|
|
439
508
|
if exp_avg is None:
|
@@ -453,7 +522,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
453
522
|
out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
454
523
|
|
455
524
|
subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
|
456
|
-
exp_avg_new =
|
525
|
+
exp_avg_new = compiled_einsum(
|
457
526
|
subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
|
458
527
|
)
|
459
528
|
copy_stochastic_(exp_avg, exp_avg_new)
|
@@ -568,6 +637,19 @@ def scalar_guard(*args):
|
|
568
637
|
return out
|
569
638
|
|
570
639
|
|
640
|
+
def broadcastable_list_guard(*xs):
|
641
|
+
xs = list_guard(*xs)
|
642
|
+
for x in xs:
|
643
|
+
if isinstance(x[0], Tensor):
|
644
|
+
ref = x[0]
|
645
|
+
break
|
646
|
+
else:
|
647
|
+
raise ValueError("No tensor-valued input given")
|
648
|
+
xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
|
649
|
+
max_len = max(len(x) for x in xs)
|
650
|
+
return [x if len(x) > 1 else x * max_len for x in xs]
|
651
|
+
|
652
|
+
|
571
653
|
@decorator_knowngood
|
572
654
|
def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
|
573
655
|
for x_, y_ in zip(x, y):
|
@@ -577,7 +659,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
|
|
577
659
|
|
578
660
|
|
579
661
|
def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
|
580
|
-
x, y =
|
662
|
+
x, y = broadcastable_list_guard(x, y)
|
581
663
|
alpha = scalar_guard(alpha, x[0])
|
582
664
|
_compilable_stochastic_add_(x, y, alpha)
|
583
665
|
|
@@ -591,7 +673,7 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
|
|
591
673
|
|
592
674
|
|
593
675
|
def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
|
594
|
-
x, y =
|
676
|
+
x, y = broadcastable_list_guard(x, y)
|
595
677
|
alpha, divisor = scalar_guard(alpha, divisor, x[0])
|
596
678
|
_compilable_stochastic_add_divide_(x, y, alpha, divisor)
|
597
679
|
|
@@ -605,7 +687,7 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
|
605
687
|
|
606
688
|
|
607
689
|
def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
608
|
-
x, y =
|
690
|
+
x, y = broadcastable_list_guard(x, y)
|
609
691
|
_compilable_stochastic_multiply_(x, y)
|
610
692
|
|
611
693
|
|
@@ -624,7 +706,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
624
706
|
b = einsum_base[idx]
|
625
707
|
g0 = einsum_base[: grad.dim()]
|
626
708
|
g1 = g0.replace(b, b.upper())
|
627
|
-
outer_product =
|
709
|
+
outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
|
628
710
|
stochastic_lerp_(m, outer_product, 1 - beta)
|
629
711
|
|
630
712
|
|
@@ -706,7 +788,7 @@ def project(grad, Q, back: bool):
|
|
706
788
|
preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
|
707
789
|
if preconditioners:
|
708
790
|
out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
|
709
|
-
out =
|
791
|
+
out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
|
710
792
|
grad = out.to(grad.dtype)
|
711
793
|
return grad
|
712
794
|
|
@@ -714,24 +796,28 @@ def project(grad, Q, back: bool):
|
|
714
796
|
@contextlib.contextmanager
|
715
797
|
def patch_backward():
|
716
798
|
@contextlib.contextmanager
|
717
|
-
def
|
799
|
+
def patch_module(module):
|
718
800
|
original = module.backward
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
801
|
+
try:
|
802
|
+
signature = inspect.signature(original)
|
803
|
+
|
804
|
+
@functools.wraps(original)
|
805
|
+
def patched_backward(*args, **kwargs):
|
806
|
+
new_kwargs = signature.bind(*args)
|
807
|
+
new_kwargs.apply_defaults()
|
808
|
+
new_kwargs = new_kwargs.arguments
|
809
|
+
new_kwargs.update(kwargs)
|
810
|
+
new_kwargs["create_graph"] = True
|
811
|
+
return original(**new_kwargs)
|
812
|
+
|
813
|
+
module.backward = patched_backward
|
814
|
+
yield
|
815
|
+
finally:
|
816
|
+
module.backward = original
|
817
|
+
|
818
|
+
with contextlib.ExitStack() as stack:
|
819
|
+
stack.enter_context(patch_module(torch.Tensor))
|
820
|
+
stack.enter_context(patch_module(torch.autograd))
|
735
821
|
yield
|
736
822
|
|
737
823
|
|
@@ -743,6 +829,9 @@ class ExactHVPFailed(ValueError):
|
|
743
829
|
pass
|
744
830
|
|
745
831
|
|
832
|
+
use_default = object()
|
833
|
+
|
834
|
+
|
746
835
|
class StatefulOptimizer(torch.optim.Optimizer):
|
747
836
|
"""
|
748
837
|
finite_differences saves memory, but needs more compute. (Alternative is true HVP)
|
@@ -755,7 +844,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
755
844
|
compile_step: bool = False
|
756
845
|
hessian_approx: bool = False
|
757
846
|
precond_schedule: Union[Callable, float, None] = None
|
758
|
-
stochastic_schedule: bool =
|
847
|
+
stochastic_schedule: bool | Literal[use_default] = use_default
|
759
848
|
finite_differences: bool = False
|
760
849
|
fallback_to_finite_differences: bool = True
|
761
850
|
_fallback_enabled: bool = False
|
@@ -765,18 +854,61 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
765
854
|
super().__init__(params, {**defaults, "foreach": foreach})
|
766
855
|
self.use_ema = use_ema
|
767
856
|
self.mapping = {}
|
768
|
-
self.
|
769
|
-
|
857
|
+
self.mapping_inverse = {}
|
858
|
+
|
859
|
+
if self.stochastic_schedule is use_default:
|
860
|
+
stochastic_schedule = None
|
861
|
+
for group in self.param_groups:
|
862
|
+
new = group.get("stochastic_schedule", stochastic_schedule)
|
863
|
+
if stochastic_schedule is not None and new != stochastic_schedule:
|
864
|
+
raise ValueError("All parameter groups must have the same stochastic_schedule.")
|
865
|
+
stochastic_schedule = new
|
866
|
+
self.stochastic_schedule = stochastic_schedule
|
867
|
+
|
868
|
+
self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
|
869
|
+
self.precond_rng = random.Random(0x12312)
|
770
870
|
self._is_preconditioning = None
|
771
871
|
|
772
872
|
if self.hessian_approx and self.compile_step:
|
773
873
|
raise ValueError("Hessian approximation can't be used with compile_step.")
|
774
874
|
|
875
|
+
self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
|
876
|
+
self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
|
877
|
+
self._init_mapping()
|
878
|
+
|
879
|
+
def _store_stats(self, state_dict: dict[str, any]):
|
880
|
+
state_dict["heavyball"] = {
|
881
|
+
"inner_group": self.inner_group,
|
882
|
+
"precond_rng": pickle.dumps(self.precond_rng),
|
883
|
+
"use_ema": self.use_ema,
|
884
|
+
"ema_decay": self.ema_decay,
|
885
|
+
"compile_step": self.compile_step,
|
886
|
+
"hessian_approx": self.hessian_approx,
|
887
|
+
"precond_schedule": pickle.dumps(self.precond_schedule),
|
888
|
+
"stochastic_schedule": self.stochastic_schedule,
|
889
|
+
"fallback_to_finite_differences": self.fallback_to_finite_differences,
|
890
|
+
"_fallback_enabled": self._fallback_enabled,
|
891
|
+
"hvp_interval": self.hvp_interval,
|
892
|
+
}
|
893
|
+
|
894
|
+
def _load_stats(self, state_dict):
|
895
|
+
sd = state_dict.pop("heavyball", {})
|
896
|
+
for k, v in sd.items():
|
897
|
+
if k in ("precond_rng", "precond_schedule"):
|
898
|
+
v = pickle.loads(v)
|
899
|
+
setattr(self, k, v)
|
900
|
+
|
775
901
|
def get_groups(self, group):
|
776
902
|
return [group]
|
777
903
|
|
778
|
-
|
779
|
-
|
904
|
+
@functools.lru_cache(maxsize=None)
|
905
|
+
def state_(self, arg: Tensor, fail: bool = True):
|
906
|
+
if not fail and arg not in self.mapping:
|
907
|
+
return {}
|
908
|
+
state_param, index = self.mapping_inverse[arg]
|
909
|
+
if state_param not in self.state:
|
910
|
+
self.state[state_param] = collections.defaultdict(dict)
|
911
|
+
return self.state[state_param][index]
|
780
912
|
|
781
913
|
def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
|
782
914
|
for p, g in zip(p_list, g_list):
|
@@ -786,6 +918,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
786
918
|
old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
|
787
919
|
mars_correction(g_list, old_gs, mars_gamma, beta)
|
788
920
|
|
921
|
+
def _init_mapping(self, group: dict | None = None):
|
922
|
+
if group is None:
|
923
|
+
for group in self.param_groups:
|
924
|
+
self._init_mapping(group)
|
925
|
+
return
|
926
|
+
|
927
|
+
for p in group["params"]:
|
928
|
+
if p not in self.mapping:
|
929
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
930
|
+
for i, pv in enumerate(p_views):
|
931
|
+
self.mapping_inverse[pv] = (p, i)
|
932
|
+
|
789
933
|
def split_p_and_g_in_group(
|
790
934
|
self,
|
791
935
|
group: dict,
|
@@ -809,6 +953,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
809
953
|
p_views = self.mapping[p]
|
810
954
|
else:
|
811
955
|
self.mapping[p] = p_views = merge_group(group, p)
|
956
|
+
for i, pv in enumerate(p_views):
|
957
|
+
self.mapping_inverse[pv] = (p, i)
|
812
958
|
|
813
959
|
vector = getattr(p, "vector", None)
|
814
960
|
hessian_vector = getattr(p, "hessian_vector", None)
|
@@ -957,8 +1103,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
957
1103
|
raise ValueError("Hessian approximation requires a closure.")
|
958
1104
|
return None
|
959
1105
|
|
960
|
-
step = self.
|
961
|
-
if not hessian_approx or step % self.hvp_interval == 0:
|
1106
|
+
step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
|
1107
|
+
if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
|
962
1108
|
with torch.enable_grad():
|
963
1109
|
loss = closure()
|
964
1110
|
return loss
|
@@ -997,12 +1143,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
997
1143
|
if self.precond_schedule is None:
|
998
1144
|
self._is_preconditioning = False
|
999
1145
|
else:
|
1000
|
-
self._is_preconditioning = psgd_should_update(self.
|
1146
|
+
self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
|
1001
1147
|
loss = self._handle_closure(closure)
|
1002
1148
|
|
1003
1149
|
# we assume that parameters are constant and that there are no excessive recompiles
|
1004
1150
|
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
|
1005
1151
|
for group in self.param_groups:
|
1152
|
+
if "param_count" not in group:
|
1153
|
+
group["param_count"] = sum(p.numel() for p in group["params"])
|
1006
1154
|
group["is_preconditioning"] = self._is_preconditioning
|
1007
1155
|
self._step(group)
|
1008
1156
|
if self.use_ema:
|
@@ -1306,83 +1454,115 @@ def stable_exp(x: Tensor):
|
|
1306
1454
|
return torch.where(x > 0, 1 / (-x).exp(), x.exp())
|
1307
1455
|
|
1308
1456
|
|
1457
|
+
def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
|
1458
|
+
# ln(mean(x ** pow) ** (1 / pow / 2))
|
1459
|
+
normalization = math.log(x.numel())
|
1460
|
+
x = x.double()
|
1461
|
+
x = x.abs()
|
1462
|
+
x = x.clamp(min=eps)
|
1463
|
+
x = x.log()
|
1464
|
+
x = x * pow
|
1465
|
+
x = x.flatten()
|
1466
|
+
x = x.logsumexp(dim=0) # log(sum(exp( log(x) * P ) - more stable than sum(x ** P)
|
1467
|
+
x = x - normalization # sum -> mean (divide by x.numel() in log space)
|
1468
|
+
return x / pow / 2
|
1469
|
+
|
1470
|
+
|
1309
1471
|
@decorator_knowngood
|
1310
1472
|
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
|
1311
1473
|
# 1 / (mean(x ** pow) ** (1 / pow / 2))
|
1312
|
-
|
1313
|
-
log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
|
1314
|
-
return stable_exp(-log_mean_x_pow / pow / 2)
|
1474
|
+
return stable_exp(-_lse_mean(x, pow, eps))
|
1315
1475
|
|
1316
1476
|
|
1317
1477
|
@decorator_knowngood
|
1318
1478
|
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
|
1319
1479
|
# mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
|
1320
|
-
|
1321
|
-
log_y = y.double().abs().clamp(min=eps).log()
|
1322
|
-
|
1323
|
-
x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
|
1324
|
-
x_normed = x_normed / pow0 / 2
|
1480
|
+
return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
|
1325
1481
|
|
1326
|
-
y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
|
1327
|
-
y_normed = y_normed / pow1 / 2
|
1328
1482
|
|
1329
|
-
|
1483
|
+
class PrecondInitError(ValueError):
|
1484
|
+
pass
|
1330
1485
|
|
1331
1486
|
|
1332
|
-
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float =
|
1487
|
+
def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
|
1333
1488
|
automatic_scale = True
|
1334
1489
|
manual_hint = " Set it manually using `precond_init_scale=0.1`"
|
1490
|
+
scale_scale = 1 if scale_scale is None else scale_scale
|
1335
1491
|
|
1336
1492
|
if scale is not None:
|
1337
1493
|
automatic_scale = False
|
1338
1494
|
warn_once(
|
1339
1495
|
"It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
|
1340
1496
|
)
|
1341
|
-
if scale_scale
|
1497
|
+
if scale_scale != 1:
|
1342
1498
|
warn_once(
|
1343
|
-
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly
|
1499
|
+
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
|
1500
|
+
)
|
1501
|
+
if scale_power is not None:
|
1502
|
+
warn_once(
|
1503
|
+
"precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
|
1344
1504
|
)
|
1345
1505
|
elif hessian_vector is None:
|
1346
1506
|
scale = mean_root(grad, 4) * scale_scale
|
1347
1507
|
else:
|
1348
1508
|
scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
|
1349
1509
|
|
1510
|
+
if automatic_scale:
|
1511
|
+
scale_power = 0.5 if scale_power is None else scale_power
|
1512
|
+
scale = scale**scale_power
|
1513
|
+
|
1350
1514
|
if isinstance(scale, torch.Tensor):
|
1351
1515
|
scale = scale.item() # slow, but necessary
|
1352
1516
|
|
1353
1517
|
if np.isfinite(scale):
|
1354
|
-
if scale > scale_max
|
1518
|
+
if scale > scale_max: # fallthrough to later checks
|
1355
1519
|
warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
|
1356
1520
|
else:
|
1357
1521
|
return scale
|
1358
1522
|
|
1359
1523
|
if not automatic_scale:
|
1360
|
-
raise
|
1524
|
+
raise PrecondInitError("The manually set precond_init_scale is not finite")
|
1361
1525
|
|
1362
1526
|
for x in (grad, hessian_vector, vector):
|
1363
1527
|
if x is None:
|
1364
1528
|
continue
|
1365
|
-
if torch.allclose(x, torch.zeros_like(x))
|
1366
|
-
raise
|
1529
|
+
if torch.allclose(x, torch.zeros_like(x)):
|
1530
|
+
raise PrecondInitError(
|
1531
|
+
f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
|
1532
|
+
)
|
1367
1533
|
if not torch.isfinite(x).all().item():
|
1368
|
-
raise
|
1534
|
+
raise PrecondInitError("Grad or HVP is not finite")
|
1369
1535
|
|
1370
1536
|
if np.isfinite(scale):
|
1371
1537
|
return scale
|
1372
1538
|
|
1373
|
-
raise
|
1539
|
+
raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
|
1374
1540
|
|
1375
1541
|
|
1376
|
-
def init_lra(
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1542
|
+
def init_lra(
|
1543
|
+
grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
|
1544
|
+
):
|
1545
|
+
# "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
|
1546
|
+
# https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
|
1547
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1548
|
+
uv_scale = (param_count * (rank + eps)) ** -0.5
|
1549
|
+
U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
|
1550
|
+
V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
|
1380
1551
|
d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
|
1381
1552
|
return U, V, d
|
1382
1553
|
|
1383
1554
|
|
1384
1555
|
def init_Q_exprs(
|
1385
|
-
grad,
|
1556
|
+
grad,
|
1557
|
+
scale,
|
1558
|
+
scale_scale,
|
1559
|
+
scale_power,
|
1560
|
+
max_size,
|
1561
|
+
min_ndim_triangular,
|
1562
|
+
memory_save_mode,
|
1563
|
+
hessian_vector,
|
1564
|
+
vector,
|
1565
|
+
dtype=None,
|
1386
1566
|
):
|
1387
1567
|
"""
|
1388
1568
|
For a scalar or tensor `grad`, we initialize its preconditioner Q and
|
@@ -1391,21 +1571,13 @@ def init_Q_exprs(
|
|
1391
1571
|
precond init scale computation from
|
1392
1572
|
https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
|
1393
1573
|
"""
|
1394
|
-
scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
|
1395
|
-
letters = string.ascii_lowercase + string.ascii_uppercase
|
1574
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1396
1575
|
dtype = dtype if dtype is not None else grad.dtype
|
1397
1576
|
shape = grad.shape
|
1398
1577
|
|
1399
1578
|
if len(shape) == 0: # scalar
|
1400
1579
|
Q = [scale * torch.ones_like(grad, dtype=dtype)]
|
1401
|
-
|
1402
|
-
exprGs = [",->"]
|
1403
|
-
exprP = ",,->"
|
1404
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1405
|
-
|
1406
|
-
# Tensor
|
1407
|
-
if len(shape) > 13:
|
1408
|
-
raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
|
1580
|
+
return Q
|
1409
1581
|
|
1410
1582
|
scale = scale ** (1 / len(shape))
|
1411
1583
|
|
@@ -1418,6 +1590,9 @@ def init_Q_exprs(
|
|
1418
1590
|
sorted_shape = sorted(shape)
|
1419
1591
|
if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
|
1420
1592
|
dim_diag[_max_idx(shape)] = True
|
1593
|
+
elif memory_save_mode == "one_triu":
|
1594
|
+
shape_ranks = np.argsort(np.argsort(shape)) # ranks
|
1595
|
+
dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
|
1421
1596
|
elif memory_save_mode == "all_diag":
|
1422
1597
|
dim_diag = [True for _ in shape]
|
1423
1598
|
else:
|
@@ -1427,66 +1602,90 @@ def init_Q_exprs(
|
|
1427
1602
|
)
|
1428
1603
|
|
1429
1604
|
Q = []
|
1430
|
-
piece1A, piece2A, piece3A = ([], "", "")
|
1431
|
-
exprGs = []
|
1432
|
-
piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
|
1433
1605
|
for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
|
1434
1606
|
if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
|
1435
1607
|
# use diagonal matrix as preconditioner for this dim
|
1436
1608
|
Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
|
1437
|
-
|
1438
|
-
piece1A.append(letters[i])
|
1439
|
-
piece2A = piece2A + letters[i]
|
1440
|
-
piece3A = piece3A + letters[i]
|
1441
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1442
|
-
subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
|
1443
|
-
exprGs.append(subscripts)
|
1444
|
-
piece1P.append(letters[i + 13])
|
1445
|
-
piece2P.append(letters[i + 13])
|
1446
|
-
piece3P = piece3P + letters[i + 13]
|
1447
|
-
piece4P = piece4P + letters[i + 13]
|
1448
1609
|
else:
|
1449
1610
|
# use triangular matrix as preconditioner for this dim
|
1450
1611
|
Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
|
1451
|
-
|
1452
|
-
piece2A = piece2A + letters[i + 13]
|
1453
|
-
piece3A = piece3A + letters[i]
|
1454
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1455
|
-
piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
|
1456
|
-
subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
|
1457
|
-
exprGs.append(subscripts)
|
1458
|
-
a, b, c = (letters[i], letters[i + 13], letters[i + 26])
|
1459
|
-
piece1P.append(a + b)
|
1460
|
-
piece2P.append(a + c)
|
1461
|
-
piece3P = piece3P + c
|
1462
|
-
piece4P = piece4P + b
|
1463
|
-
|
1464
|
-
exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
|
1465
|
-
exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
|
1466
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1612
|
+
return Q
|
1467
1613
|
|
1468
1614
|
|
1469
|
-
@
|
1470
|
-
def psgd_balance_Q(
|
1471
|
-
norms =
|
1472
|
-
geometric_mean = norms
|
1473
|
-
|
1474
|
-
|
1615
|
+
@decorator_knowngood
|
1616
|
+
def psgd_balance_Q(Q):
|
1617
|
+
norms = [promote(q.norm(float("inf"))).log() for q in Q]
|
1618
|
+
geometric_mean = sum([n for n in norms]) / len(Q)
|
1619
|
+
for q, n in zip(Q, norms):
|
1620
|
+
q *= (geometric_mean - n).exp()
|
1475
1621
|
|
1476
1622
|
|
1477
|
-
@
|
1478
|
-
def
|
1479
|
-
u_norm =
|
1480
|
-
v_norm =
|
1481
|
-
scale = (u_norm / v_norm) ** 0.
|
1482
|
-
|
1483
|
-
|
1623
|
+
@decorator_knowngood
|
1624
|
+
def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
|
1625
|
+
u_norm = sum(u.square().sum().double() for u in U)
|
1626
|
+
v_norm = sum(v.square().sum().double() for v in V)
|
1627
|
+
scale = (u_norm / v_norm) ** 0.25 # sqrt of L2 norms; sqrt, as it's 2 factors
|
1628
|
+
scale = torch.where(torch.logical_and(torch.isfinite(scale), scale > 1e-6), scale, 1)
|
1629
|
+
stochastic_multiply_(U, [1 / scale] * len(U))
|
1630
|
+
stochastic_multiply_(V, [scale] * len(V))
|
1631
|
+
return multi_flatten((U, 1), (V, 1), (d, 0))
|
1484
1632
|
|
1485
1633
|
|
1486
1634
|
@decorator
|
1487
1635
|
def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
|
1488
1636
|
dtype = min_dtype([U, V, x])
|
1489
|
-
return x +
|
1637
|
+
return x + compiled_einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
|
1638
|
+
|
1639
|
+
|
1640
|
+
@decorator_knowngood
|
1641
|
+
def _compilable_d_step(
|
1642
|
+
d: Tensor,
|
1643
|
+
d_orig: List[Tensor],
|
1644
|
+
invQtv: Tensor,
|
1645
|
+
vector: Tensor,
|
1646
|
+
inverse_precond_vector: Tensor,
|
1647
|
+
hessian_vector: Tensor,
|
1648
|
+
precond_hessian_vector: Tensor,
|
1649
|
+
eps: Tensor,
|
1650
|
+
step: Tensor,
|
1651
|
+
delayed: bool,
|
1652
|
+
):
|
1653
|
+
precond_hessian_vector = promote(precond_hessian_vector)
|
1654
|
+
hessian_vector = promote(hessian_vector)
|
1655
|
+
vector = promote(vector)
|
1656
|
+
inverse_precond_vector = promote(inverse_precond_vector)
|
1657
|
+
invQtv = promote(invQtv)
|
1658
|
+
inverse_precond_vector = invQtv - inverse_precond_vector
|
1659
|
+
|
1660
|
+
nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector
|
1661
|
+
|
1662
|
+
"""
|
1663
|
+
1) Sketching
|
1664
|
+
1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
|
1665
|
+
1.2) reduced-precision selection of largest element (halves memory traffic)
|
1666
|
+
2) Computation
|
1667
|
+
2.1) select relevant indices
|
1668
|
+
2.2) redo 1.1 in double precision for scalar values
|
1669
|
+
2.3) return high-precision normalized step-size
|
1670
|
+
overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
|
1671
|
+
improving precision
|
1672
|
+
"""
|
1673
|
+
a0 = promote(d) * precond_hessian_vector
|
1674
|
+
a1 = vector
|
1675
|
+
b0 = inverse_precond_vector / promote(d)
|
1676
|
+
b1 = hessian_vector
|
1677
|
+
|
1678
|
+
divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
|
1679
|
+
idx = divisor.bfloat16().flatten().argmax()
|
1680
|
+
a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
|
1681
|
+
b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
|
1682
|
+
divisor = (a * b).sqrt().clamp(min=eps)
|
1683
|
+
step = -step / divisor
|
1684
|
+
|
1685
|
+
# fused update(s)
|
1686
|
+
apply_flat_add(d_orig, nablaD, step)
|
1687
|
+
if not delayed:
|
1688
|
+
copy_stochastic_(d, promote(d) - nablaD * step)
|
1490
1689
|
|
1491
1690
|
|
1492
1691
|
def update_lra_precond_(
|
@@ -1498,13 +1697,14 @@ def update_lra_precond_(
|
|
1498
1697
|
eps: float,
|
1499
1698
|
step: float,
|
1500
1699
|
delayed: bool,
|
1700
|
+
precond_u: bool,
|
1501
1701
|
):
|
1502
1702
|
"""
|
1503
1703
|
Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
|
1504
1704
|
"""
|
1505
1705
|
U_orig, V_orig, d_orig = U, V, d
|
1506
1706
|
|
1507
|
-
U, V, d =
|
1707
|
+
U, V, d = _lra_flatten_and_balance(U, V, d)
|
1508
1708
|
|
1509
1709
|
dtype = min_dtype([U, V, vector, hessian_vector])
|
1510
1710
|
U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
|
@@ -1512,10 +1712,10 @@ def update_lra_precond_(
|
|
1512
1712
|
eps = scalar_guard(eps, vector)
|
1513
1713
|
|
1514
1714
|
Qh = low_rank_mm(U, V, d * hessian_vector)
|
1515
|
-
Ph =
|
1715
|
+
Ph = low_rank_mm(V, U, Qh)
|
1516
1716
|
rank = U.size(1)
|
1517
1717
|
|
1518
|
-
VtU =
|
1718
|
+
VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
|
1519
1719
|
I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
|
1520
1720
|
IpVtU = I + VtU
|
1521
1721
|
invQtv = vector / d
|
@@ -1533,47 +1733,39 @@ def update_lra_precond_(
|
|
1533
1733
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1534
1734
|
|
1535
1735
|
invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
|
1536
|
-
invPv =
|
1537
|
-
invPv = invPv / d
|
1736
|
+
invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
|
1538
1737
|
|
1539
|
-
|
1540
|
-
|
1541
|
-
divisor = divisor.add(eps).sqrt().max()
|
1542
|
-
d_step = step / divisor
|
1543
|
-
|
1544
|
-
apply_flat_add(d_orig, d * nablaD, -d_step)
|
1738
|
+
eps, step = scalar_guard(eps, step, vector)
|
1739
|
+
_compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
|
1545
1740
|
|
1546
1741
|
a, b = Qh, invQtv
|
1547
1742
|
|
1548
|
-
precond_u = random.random() < 0.5 # update either U or V, not both at the same time
|
1549
1743
|
precond = V if precond_u else U
|
1550
|
-
atV =
|
1551
|
-
btV =
|
1552
|
-
atVVt =
|
1553
|
-
btVVt =
|
1554
|
-
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()
|
1744
|
+
atV = compiled_einsum("b,br->r", a, precond) # o == one
|
1745
|
+
btV = compiled_einsum("b,br->r", b, precond)
|
1746
|
+
atVVt = compiled_einsum("r,br->b", atV, precond)
|
1747
|
+
btVVt = compiled_einsum("r,br->b", btV, precond)
|
1748
|
+
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
|
1555
1749
|
if precond_u:
|
1556
|
-
a =
|
1557
|
-
b =
|
1750
|
+
a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
|
1751
|
+
b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
|
1558
1752
|
else:
|
1559
|
-
a = a +
|
1560
|
-
b = b +
|
1561
|
-
a =
|
1562
|
-
b =
|
1753
|
+
a = a + compiled_einsum("br,r->b", V, atV)
|
1754
|
+
b = b + compiled_einsum("br,r->b", V, btV)
|
1755
|
+
a = compiled_einsum("b,r->br", a, atV)
|
1756
|
+
b = compiled_einsum("b,r->br", b, btV)
|
1563
1757
|
apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
|
1564
|
-
|
1565
1758
|
if not delayed:
|
1566
|
-
stochastic_add_([d], [d * nablaD], -d_step)
|
1567
1759
|
stochastic_add_([U if precond_u else V], [b - a], precond_step)
|
1568
1760
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1569
1761
|
|
1570
1762
|
|
1571
|
-
def lra_precond(U, V, d, g):
|
1763
|
+
def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
|
1572
1764
|
"""
|
1573
1765
|
As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
|
1574
1766
|
"""
|
1575
|
-
|
1576
|
-
return d * low_rank_mm(V, U,
|
1767
|
+
new_g = low_rank_mm(U, V, d * g)
|
1768
|
+
return d * low_rank_mm(V, U, new_g)
|
1577
1769
|
|
1578
1770
|
|
1579
1771
|
@decorator_knowngood
|
@@ -1584,16 +1776,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
|
|
1584
1776
|
|
1585
1777
|
|
1586
1778
|
@decorator_knowngood
|
1587
|
-
def
|
1588
|
-
|
1779
|
+
def _compilable_lra_update_(
|
1780
|
+
params: List[Tensor],
|
1781
|
+
update: List[Tensor],
|
1782
|
+
U: Tensor,
|
1783
|
+
V: Tensor,
|
1784
|
+
d: Tensor,
|
1785
|
+
lr: Tensor,
|
1786
|
+
decay: Tensor,
|
1787
|
+
caution: bool,
|
1788
|
+
grads: List[Tensor],
|
1789
|
+
):
|
1790
|
+
update = lra_precond(U, V, d, flatten(update))
|
1589
1791
|
start = 0
|
1590
1792
|
update = update.flatten()
|
1591
|
-
for p in params:
|
1793
|
+
for p, g in zip(params, grads):
|
1592
1794
|
size = p.numel()
|
1593
|
-
|
1795
|
+
update_param_(p, update[start : start + size].view_as(p), lr, decay, caution, g)
|
1594
1796
|
start += size
|
1595
1797
|
|
1596
1798
|
|
1799
|
+
def apply_lra_update(
|
1800
|
+
params: List[Tensor],
|
1801
|
+
update: Tensor,
|
1802
|
+
U: Tensor,
|
1803
|
+
V: Tensor,
|
1804
|
+
d: Tensor,
|
1805
|
+
lr: float,
|
1806
|
+
decay: float,
|
1807
|
+
caution: bool,
|
1808
|
+
grads: List[Tensor],
|
1809
|
+
):
|
1810
|
+
params, grads = list_guard(params, grads)
|
1811
|
+
lr, decay = scalar_guard(lr, decay, params[0])
|
1812
|
+
_compilable_lra_update_(params, update, U, V, d, lr, decay, caution, grads)
|
1813
|
+
|
1814
|
+
|
1597
1815
|
@decorator_knowngood
|
1598
1816
|
def apply_flat_update(params: List[Tensor], update: Tensor):
|
1599
1817
|
start = 0
|
@@ -1604,6 +1822,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
|
|
1604
1822
|
start += size
|
1605
1823
|
|
1606
1824
|
|
1825
|
+
@decorator_knowngood
|
1826
|
+
def zero_(x: List[Tensor]):
|
1827
|
+
for i in x:
|
1828
|
+
i.zero_()
|
1829
|
+
|
1830
|
+
|
1607
1831
|
@decorator_knowngood
|
1608
1832
|
def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
|
1609
1833
|
start = 0
|
@@ -1629,7 +1853,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
|
|
1629
1853
|
@decorator_knowngood
|
1630
1854
|
def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
|
1631
1855
|
last_dim = x[0].shape[-remaining:] if remaining else []
|
1632
|
-
return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
|
1856
|
+
return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
|
1857
|
+
|
1858
|
+
|
1859
|
+
@decorator_knowngood
|
1860
|
+
def multi_flatten(*xs: Tuple[List[Tensor], int]):
|
1861
|
+
return [flatten(x, i) for x, i in xs]
|
1633
1862
|
|
1634
1863
|
|
1635
1864
|
@decorator_knowngood
|
@@ -1645,110 +1874,275 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
|
|
1645
1874
|
|
1646
1875
|
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
|
1647
1876
|
md = min_dtype(args)
|
1648
|
-
return
|
1877
|
+
return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
|
1649
1878
|
|
1650
1879
|
|
1651
1880
|
@decorator_knowngood
|
1652
1881
|
def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
|
1653
1882
|
triangular_qs = []
|
1883
|
+
conjB = promote(conjB)
|
1654
1884
|
for i, q in enumerate(Qs):
|
1655
1885
|
q = promote(q)
|
1656
1886
|
if q.dim() <= 1:
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
1887
|
+
if conjB.ndim == 0:
|
1888
|
+
conjB = conjB / q
|
1889
|
+
else:
|
1890
|
+
shape = [1] * conjB.ndim
|
1891
|
+
shape[i] = -1
|
1892
|
+
conjB = conjB / q.view(shape)
|
1660
1893
|
else:
|
1661
1894
|
triangular_qs.append((i, q))
|
1662
|
-
return triangular_qs
|
1895
|
+
return triangular_qs, conjB
|
1663
1896
|
|
1664
1897
|
|
1665
1898
|
@decorator_knowngood
|
1666
|
-
def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int,
|
1899
|
+
def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
|
1900
|
+
solved = solved.reshape(transposed_shape)
|
1901
|
+
solved = solved.transpose(-1, last_dim)
|
1667
1902
|
solved = solved.reshape(original_shape)
|
1668
|
-
solved.transpose(
|
1669
|
-
return solved.
|
1903
|
+
solved = solved.transpose(-1, new_dim)
|
1904
|
+
return solved.contiguous(), solved.shape
|
1905
|
+
|
1906
|
+
|
1907
|
+
def ndim_tuple(Q: list[Tensor]) -> tuple:
|
1908
|
+
return tuple(q.ndim for q in Q)
|
1670
1909
|
|
1671
1910
|
|
1672
|
-
def psgd_calc_A_and_conjB(
|
1673
|
-
|
1674
|
-
if order > 1:
|
1675
|
-
conjB = conjB.view_as(G).permute(*range(1, order), 0)
|
1676
|
-
conjB = conjB.to(promote(G.dtype))
|
1911
|
+
def psgd_calc_A_and_conjB(G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
|
1912
|
+
exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
|
1677
1913
|
A = casted_einsum(exprA, *Q, G)
|
1678
1914
|
solve = torch.compiler.disable(torch.linalg.solve_triangular)
|
1679
|
-
original_shape = conjB.shape
|
1915
|
+
transposed_shape = original_shape = conjB.shape
|
1680
1916
|
prev_i = -1
|
1681
|
-
|
1682
|
-
|
1917
|
+
qs, conjB = _psgd_calc_scalars_(Q, conjB)
|
1918
|
+
for i, tri_q in qs:
|
1919
|
+
conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
|
1683
1920
|
prev_i = i
|
1684
1921
|
conjB = solve(tri_q, conjB, upper=True, left=False)
|
1685
|
-
conjB = _reshape_conjB(conjB, original_shape, prev_i,
|
1922
|
+
conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
|
1686
1923
|
return A, conjB
|
1687
1924
|
|
1688
1925
|
|
1689
1926
|
@decorator_knowngood
|
1690
|
-
def
|
1691
|
-
|
1692
|
-
|
1927
|
+
def _random_projection(x: Tensor, scale: Optional[Tensor]):
|
1928
|
+
if scale is None:
|
1929
|
+
scale = x.norm(float("inf")).clamp(min=1e-8)
|
1930
|
+
k = 2 ** math.ceil(math.log2(math.log2(min(x.shape)))) # next-largest-power-of-2 of log2-of-size
|
1931
|
+
norm = x.square().sum(0)
|
1932
|
+
indices = torch.topk(norm, k, largest=True).indices
|
1933
|
+
return x.index_select(1, indices).contiguous() / scale, scale
|
1693
1934
|
|
1694
1935
|
|
1695
|
-
def
|
1696
|
-
|
1697
|
-
|
1698
|
-
|
1699
|
-
|
1700
|
-
|
1701
|
-
|
1702
|
-
|
1703
|
-
|
1936
|
+
def max_singular_value_exact(A, use_lobpcg: bool = False):
|
1937
|
+
try:
|
1938
|
+
if use_lobpcg:
|
1939
|
+
A = A @ A.T
|
1940
|
+
eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
|
1941
|
+
return eigval[0].sqrt()
|
1942
|
+
else:
|
1943
|
+
return torch.linalg.svd(A, driver="gesvdj")[1].max() # == linalg.matrix_norm(A, ord=2)
|
1944
|
+
except torch.linalg.LinAlgError:
|
1945
|
+
return torch.zeros((), device=A.device, dtype=A.dtype)
|
1704
1946
|
|
1705
1947
|
|
1706
1948
|
@decorator_knowngood
|
1707
|
-
def
|
1708
|
-
|
1949
|
+
def max_singular_value_power_iter(A: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
|
1950
|
+
"""
|
1951
|
+
Rayleigh quotient of row with the largest norm + optional power iterations
|
1952
|
+
"""
|
1953
|
+
x_norm, max_idx = A.norm(dim=1).max(dim=0)
|
1954
|
+
x = A.index_select(0, max_idx).flatten().contiguous()
|
1955
|
+
A = A / x_norm
|
1956
|
+
x = x / x_norm
|
1957
|
+
for _ in range(iterations):
|
1958
|
+
x = A.T.mv(A.mv(x)) # A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
|
1959
|
+
x = x / x.norm()
|
1960
|
+
return (x @ A.T.mv(A.mv(x))).sqrt() * x_norm
|
1709
1961
|
|
1710
1962
|
|
1711
1963
|
@decorator_knowngood
|
1712
|
-
def
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1964
|
+
def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
|
1965
|
+
"""
|
1966
|
+
Adapted from @evanatyourservice
|
1967
|
+
"""
|
1968
|
+
Y, max_abs = _random_projection(A, max_abs)
|
1969
|
+
Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
|
1970
|
+
Q = Q / max_abs
|
1971
|
+
Z = A.T @ Q
|
1972
|
+
W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
|
1973
|
+
sketch_norm = max_singular_value_exact(Z.T @ W)
|
1974
|
+
return sketch_norm * max_abs
|
1716
1975
|
|
1717
1976
|
|
1718
1977
|
@decorator_knowngood
|
1719
|
-
def
|
1720
|
-
|
1978
|
+
def max_singular_value(
|
1979
|
+
A: Tensor, max_abs: Optional[Tensor], max_svd: int = 32, use_cholesky: bool = False, power_iter: int = 0
|
1980
|
+
) -> Tensor:
|
1981
|
+
if min(A.shape) <= max_svd:
|
1982
|
+
return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
|
1983
|
+
if use_cholesky or power_iter < 0:
|
1984
|
+
return max_singular_value_cholesky(A, max_abs)
|
1985
|
+
return max_singular_value_power_iter(A, None, iterations=power_iter)
|
1986
|
+
|
1987
|
+
|
1988
|
+
@decorator_knowngood
|
1989
|
+
def _psgd_default_preconditioner_grad(
|
1990
|
+
terms: List[Tuple[Tensor, Tensor]],
|
1991
|
+
Q: List[Tensor],
|
1992
|
+
) -> List[Tensor]:
|
1993
|
+
out = []
|
1994
|
+
for q, (x, y) in zip(Q, terms):
|
1995
|
+
x = promote(x)
|
1996
|
+
y = promote(y)
|
1997
|
+
update = x - y
|
1998
|
+
if q.ndim < 2:
|
1999
|
+
update = q * update
|
2000
|
+
else:
|
2001
|
+
update = (q @ update).triu()
|
2002
|
+
out.append(update)
|
2003
|
+
return out
|
1721
2004
|
|
1722
2005
|
|
1723
2006
|
@decorator_knowngood
|
1724
|
-
def
|
1725
|
-
|
1726
|
-
|
1727
|
-
|
2007
|
+
def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
|
2008
|
+
if isinstance(Q[0], tuple):
|
2009
|
+
psgd_balance_Q([o[1] for o in Q])
|
2010
|
+
return line_to_triu(Q, symmetric_output)
|
2011
|
+
psgd_balance_Q(Q)
|
2012
|
+
return Q
|
2013
|
+
|
2014
|
+
|
2015
|
+
@functools.lru_cache(maxsize=None)
|
2016
|
+
def calcG_expr(q_dim, g_dim):
|
2017
|
+
exprs = []
|
2018
|
+
base = einsum_base[:g_dim]
|
2019
|
+
for i, q in enumerate(q_dim):
|
2020
|
+
new = list(base)
|
2021
|
+
if q == 2:
|
2022
|
+
new[i] = "Z"
|
2023
|
+
out = f"{base[i]}Z"
|
2024
|
+
else:
|
2025
|
+
out = base[i]
|
2026
|
+
exprs.append(f"{base},{''.join(new)}->{out}")
|
2027
|
+
return exprs
|
1728
2028
|
|
1729
2029
|
|
1730
2030
|
@decorator
|
1731
|
-
def psgd_update_precond(
|
2031
|
+
def psgd_update_precond(
|
2032
|
+
G: Tensor,
|
2033
|
+
precond_lr: float,
|
2034
|
+
oq: "TriuOrLine",
|
2035
|
+
store_triu_as_line: bool,
|
2036
|
+
velocity: Optional[List[Tensor]],
|
2037
|
+
beta2: float,
|
2038
|
+
ortho_method: Optional[str],
|
2039
|
+
V: Tensor,
|
2040
|
+
running_lower_bound: List[Tensor],
|
2041
|
+
lower_bount_beta: float,
|
2042
|
+
power_iter: int,
|
2043
|
+
) -> None:
|
1732
2044
|
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
1733
|
-
|
1734
|
-
|
1735
|
-
precond_lr = scalar_guard(precond_lr, G)
|
1736
|
-
|
1737
|
-
|
1738
|
-
|
1739
|
-
|
1740
|
-
|
1741
|
-
|
1742
|
-
|
1743
|
-
|
2045
|
+
Q = _balance_to_triu(oq)
|
2046
|
+
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
|
2047
|
+
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
|
2048
|
+
|
2049
|
+
A, conjB = psgd_calc_A_and_conjB(G, Q, V)
|
2050
|
+
terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
|
2051
|
+
del A, conjB, V
|
2052
|
+
updates = _psgd_default_preconditioner_grad(terms, Q)
|
2053
|
+
_psgd_precond_update_(
|
2054
|
+
updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
|
2055
|
+
)
|
2056
|
+
return None
|
2057
|
+
|
2058
|
+
|
2059
|
+
@decorator_knowngood
|
2060
|
+
def _psgd_precond_update_(
|
2061
|
+
matmuled: List[Optional[Tensor]],
|
2062
|
+
Q: "TriuOrLine",
|
2063
|
+
running_lower_bound: List[Tensor],
|
2064
|
+
lower_bount_beta: Tensor,
|
2065
|
+
precond_lr: Tensor,
|
2066
|
+
store_triu_as_line: bool,
|
2067
|
+
power_iter: int,
|
2068
|
+
):
|
2069
|
+
for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
|
2070
|
+
if isinstance(oq, tuple):
|
2071
|
+
oq = oq[1]
|
2072
|
+
|
2073
|
+
q = promote(oq)
|
2074
|
+
if update.ndim < 2:
|
2075
|
+
lb = update.norm(float("inf"))
|
1744
2076
|
else:
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
1748
|
-
|
1749
|
-
|
2077
|
+
lb = max_singular_value(update, None, power_iter=power_iter)
|
2078
|
+
update = promote(update)
|
2079
|
+
if store_triu_as_line:
|
2080
|
+
update = triu_to_line([update])[0][1]
|
2081
|
+
|
2082
|
+
lb = promote(lb)
|
2083
|
+
lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
|
2084
|
+
copy_stochastic_(lb_state, lb)
|
2085
|
+
copy_stochastic_(oq, q - update / lb * precond_lr)
|
2086
|
+
|
2087
|
+
|
2088
|
+
@decorator_knowngood
|
2089
|
+
def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
|
2090
|
+
"""
|
2091
|
+
I: Identity
|
2092
|
+
U: Update / gg / target
|
2093
|
+
Q: q, preconditioner
|
2094
|
+
scale: scalar scale
|
2095
|
+
---
|
2096
|
+
U = T * scale - I
|
2097
|
+
F = I - U # = 2I - U * scale
|
2098
|
+
O = F @ Q @ F - Q
|
2099
|
+
"""
|
2100
|
+
out = []
|
2101
|
+
for gg, q in zip(GG, Q):
|
2102
|
+
if gg.ndim < 2:
|
2103
|
+
scale = max(1, gg.numel()) / numel
|
2104
|
+
target = promote(gg)
|
2105
|
+
update = target * scale - 1
|
2106
|
+
out.append(q - (1 - update) * q * (1 - update))
|
1750
2107
|
else:
|
1751
|
-
|
2108
|
+
scale = gg.size(0) / numel
|
2109
|
+
gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
|
2110
|
+
update = q - gg @ q @ gg
|
2111
|
+
out.append(update + update.T) # make matrix symmetric
|
2112
|
+
return out
|
2113
|
+
|
2114
|
+
|
2115
|
+
@decorator
|
2116
|
+
def inverse_free_psgd_update_precond(
|
2117
|
+
G: Tensor,
|
2118
|
+
precond_lr: float,
|
2119
|
+
oq: List[Tensor],
|
2120
|
+
store_triu_as_line: bool,
|
2121
|
+
velocity: Optional[List[Tensor]],
|
2122
|
+
beta2: float,
|
2123
|
+
ortho_method: Optional[str],
|
2124
|
+
V: None,
|
2125
|
+
running_lower_bound: List[Tensor],
|
2126
|
+
lower_bount_beta: float,
|
2127
|
+
power_iter: int,
|
2128
|
+
) -> Tensor:
|
2129
|
+
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
2130
|
+
assert V is None
|
2131
|
+
assert ortho_method is None
|
2132
|
+
assert velocity is None
|
2133
|
+
del V, ortho_method, velocity
|
2134
|
+
|
2135
|
+
Q = _balance_to_triu(oq, True)
|
2136
|
+
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
|
2137
|
+
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
|
2138
|
+
|
2139
|
+
G = psgd_precond_grad(G, Q)
|
2140
|
+
terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
|
2141
|
+
matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
|
2142
|
+
_psgd_precond_update_(
|
2143
|
+
matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
|
2144
|
+
)
|
2145
|
+
return G
|
1752
2146
|
|
1753
2147
|
|
1754
2148
|
@decorator_knowngood
|
@@ -1785,6 +2179,34 @@ def rmsnorm_clip_(x, clip_at: float = 1.0):
|
|
1785
2179
|
return _compilable_rmsnorm_clip_(x, clip_at)
|
1786
2180
|
|
1787
2181
|
|
2182
|
+
@decorator_knowngood
|
2183
|
+
def _compilable_global_rmsnorm_clip_(x, clip_at):
|
2184
|
+
x = list(map(promote, x))
|
2185
|
+
norm = sum([x.square().sum() for x in x]) / sum([x.numel() for x in x])
|
2186
|
+
norm = norm**0.5
|
2187
|
+
norm = norm.clamp(min=clip_at)
|
2188
|
+
return torch._foreach_div(x, norm)
|
2189
|
+
|
2190
|
+
|
2191
|
+
@decorator_knowngood
|
2192
|
+
def _compilable_global_l2norm_clip_(x, clip_at):
|
2193
|
+
x = list(map(promote, x))
|
2194
|
+
norm = sum([x.square().sum() for x in x])
|
2195
|
+
norm = norm**0.5
|
2196
|
+
norm = norm.clamp(min=clip_at)
|
2197
|
+
return torch._foreach_div(x, norm)
|
2198
|
+
|
2199
|
+
|
2200
|
+
def global_rmsnorm_clip(x, clip_at: float = 1.0):
|
2201
|
+
x = list_guard(x)
|
2202
|
+
return _compilable_global_rmsnorm_clip_(x, clip_at)
|
2203
|
+
|
2204
|
+
|
2205
|
+
def global_l2norm_clip(x, clip_at: float = 1.0):
|
2206
|
+
x = list_guard(x)
|
2207
|
+
return _compilable_global_rmsnorm_clip_(x, clip_at)
|
2208
|
+
|
2209
|
+
|
1788
2210
|
def rmsnorm_normalize_(x, clip_at: float = 1e-6):
|
1789
2211
|
x = list_guard(x)
|
1790
2212
|
return _compilable_rmsnorm_clip_(x, clip_at)
|
@@ -1862,6 +2284,17 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
|
|
1862
2284
|
_compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
|
1863
2285
|
|
1864
2286
|
|
2287
|
+
@decorator_knowngood
|
2288
|
+
def _compilable_weight_decay_to_init_(p, init, weight_decay):
|
2289
|
+
_lerp(p, promote(init), 1 - weight_decay)
|
2290
|
+
|
2291
|
+
|
2292
|
+
def weight_decay_to_init_(p, init, weight_decay):
|
2293
|
+
p, init = list_guard(p, init)
|
2294
|
+
weight_decay = scalar_guard(weight_decay, p[0])
|
2295
|
+
_compilable_weight_decay_to_ema_(p, init, weight_decay)
|
2296
|
+
|
2297
|
+
|
1865
2298
|
@decorator_knowngood
|
1866
2299
|
def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
|
1867
2300
|
ema32 = _lerp(ema, p, ema_decay)
|
@@ -1920,35 +2353,25 @@ def triu_to_line(Q_list: List[Tensor]):
|
|
1920
2353
|
if q.dim() < 2:
|
1921
2354
|
out.append((None, q))
|
1922
2355
|
else:
|
1923
|
-
out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
|
2356
|
+
out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
|
1924
2357
|
return out
|
1925
2358
|
|
1926
2359
|
|
1927
|
-
|
1928
|
-
|
1929
|
-
assert n * (n + 1) == 2 * numel
|
1930
|
-
return n, n
|
1931
|
-
|
1932
|
-
|
1933
|
-
@decorator
|
1934
|
-
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
|
2360
|
+
@decorator_knowngood
|
2361
|
+
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
|
1935
2362
|
new = []
|
1936
2363
|
for shape, q in Q_list:
|
1937
2364
|
if shape is not None:
|
1938
|
-
|
1939
|
-
|
1940
|
-
x
|
1941
|
-
|
2365
|
+
x, y = torch.triu_indices(*shape, device=q.device)
|
2366
|
+
q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
|
2367
|
+
q_mat[x, y] = q
|
2368
|
+
if symmetric_output:
|
2369
|
+
q_mat[y, x] = q
|
2370
|
+
q = q_mat
|
1942
2371
|
new.append(q)
|
1943
2372
|
return new
|
1944
2373
|
|
1945
2374
|
|
1946
|
-
def update_triu_(q_state, materialised):
|
1947
|
-
for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
|
1948
|
-
assert shape0 == shape1
|
1949
|
-
copy_stochastic_(q, m)
|
1950
|
-
|
1951
|
-
|
1952
2375
|
_warned = set()
|
1953
2376
|
|
1954
2377
|
|
@@ -1971,52 +2394,118 @@ def psgd_should_update(
|
|
1971
2394
|
return int(group[name]) > int(cumulative_prob)
|
1972
2395
|
|
1973
2396
|
|
2397
|
+
@functools.lru_cache(maxsize=None)
|
2398
|
+
def cached_precond_grad_expr(Q_dim, grad_dim):
|
2399
|
+
expr = [f"{c.upper()}{c}" if q_ == 2 else c for c, q_ in zip(einsum_base, Q_dim)]
|
2400
|
+
expr = ",".join(expr)
|
2401
|
+
grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
|
2402
|
+
out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
|
2403
|
+
return f"{expr},{grad_expr}->{out_expr}"
|
2404
|
+
|
2405
|
+
|
1974
2406
|
@decorator_knowngood
|
1975
2407
|
def precond_grad_cached_(
|
1976
|
-
|
2408
|
+
ea: Tensor,
|
2409
|
+
cached_q: List[Tensor],
|
2410
|
+
caution: bool = False,
|
2411
|
+
grad: Optional[Tensor] = None,
|
2412
|
+
cast: bool = True,
|
1977
2413
|
):
|
1978
2414
|
if caution:
|
1979
2415
|
ea = _compilable_cautioning(grad, ea)
|
1980
2416
|
md = min_dtype(list(cached_q) + [ea])
|
1981
2417
|
args = [q.to(md) for q in cached_q]
|
1982
2418
|
args = args + [ea.to(md)]
|
1983
|
-
|
2419
|
+
expr = cached_precond_grad_expr(ndim_tuple(cached_q), grad.ndim)
|
2420
|
+
new = compiled_einsum(expr, *args)
|
1984
2421
|
if cast:
|
1985
2422
|
return new.to(ea.dtype)
|
1986
2423
|
return new
|
1987
2424
|
|
1988
2425
|
|
2426
|
+
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
|
2427
|
+
|
2428
|
+
|
1989
2429
|
@decorator_knowngood
|
1990
|
-
def _compilable_fused_precond_grad_cached_(
|
1991
|
-
precond = precond_grad_cached_(
|
2430
|
+
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
|
2431
|
+
precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
|
1992
2432
|
update_param_(param, precond, lr, decay, caution=False)
|
1993
2433
|
|
1994
2434
|
|
1995
|
-
def fused_precond_grad_cached_(
|
2435
|
+
def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
|
1996
2436
|
lr = scalar_guard(lr, param[0])
|
1997
|
-
_compilable_fused_precond_grad_cached_(
|
2437
|
+
_compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
|
2438
|
+
|
2439
|
+
|
2440
|
+
@functools.lru_cache(maxsize=None)
|
2441
|
+
def precond_grad_expr(Q_dim, grad_dim):
|
2442
|
+
expr = [
|
2443
|
+
f"{c2}{c.upper()},{c2}{c}" if q_ == 2 else f"{c},{c}" for c, c2, q_ in zip(einsum_base, einsum_base[13:], Q_dim)
|
2444
|
+
]
|
2445
|
+
expr = ",".join(expr)
|
2446
|
+
grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
|
2447
|
+
out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
|
2448
|
+
return f"{expr},{grad_expr}->{out_expr}"
|
1998
2449
|
|
1999
2450
|
|
2000
2451
|
@decorator_knowngood
|
2001
|
-
def psgd_precond_grad(
|
2452
|
+
def psgd_precond_grad(
|
2453
|
+
ea: Tensor,
|
2454
|
+
preconds: TriuOrLine,
|
2455
|
+
caution: bool = False,
|
2456
|
+
grad: Optional[Tensor] = None,
|
2457
|
+
store_triu_as_line: bool = False,
|
2458
|
+
symmetric_output: bool = False,
|
2459
|
+
):
|
2002
2460
|
if caution:
|
2003
2461
|
ea = _compilable_cautioning(grad, ea)
|
2462
|
+
if store_triu_as_line:
|
2463
|
+
preconds = line_to_triu(preconds, symmetric_output)
|
2004
2464
|
md = min_dtype(list(preconds) + [ea])
|
2005
2465
|
args = [q.to(md) for q in preconds]
|
2006
|
-
|
2007
|
-
new =
|
2466
|
+
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
|
2467
|
+
new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
|
2008
2468
|
return new.to(ea.dtype)
|
2009
2469
|
|
2010
2470
|
|
2011
2471
|
@decorator_knowngood
|
2012
|
-
def _compilable_fused_psgd_precond_grad(
|
2013
|
-
|
2472
|
+
def _compilable_fused_psgd_precond_grad(
|
2473
|
+
ea: Tensor,
|
2474
|
+
param,
|
2475
|
+
lr,
|
2476
|
+
grad,
|
2477
|
+
decay,
|
2478
|
+
caution,
|
2479
|
+
preconds: TriuOrLine,
|
2480
|
+
store_triu_as_line: bool = False,
|
2481
|
+
symmetric_output: bool = False,
|
2482
|
+
):
|
2483
|
+
precond = psgd_precond_grad(
|
2484
|
+
ea,
|
2485
|
+
preconds,
|
2486
|
+
caution=caution,
|
2487
|
+
grad=grad,
|
2488
|
+
store_triu_as_line=store_triu_as_line,
|
2489
|
+
symmetric_output=symmetric_output,
|
2490
|
+
)
|
2014
2491
|
update_param_(param, precond, lr, decay, caution=False, grad=grad)
|
2015
2492
|
|
2016
2493
|
|
2017
|
-
def fused_psgd_precond_grad(
|
2494
|
+
def fused_psgd_precond_grad(
|
2495
|
+
ea: Tensor,
|
2496
|
+
param,
|
2497
|
+
lr,
|
2498
|
+
grad,
|
2499
|
+
decay,
|
2500
|
+
caution,
|
2501
|
+
preconds: TriuOrLine,
|
2502
|
+
store_triu_as_line: bool = False,
|
2503
|
+
symmetric_output: bool = False,
|
2504
|
+
):
|
2018
2505
|
lr = scalar_guard(lr, param[0])
|
2019
|
-
_compilable_fused_psgd_precond_grad(
|
2506
|
+
_compilable_fused_psgd_precond_grad(
|
2507
|
+
ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
|
2508
|
+
)
|
2020
2509
|
|
2021
2510
|
|
2022
2511
|
@decorator_knowngood
|
@@ -2068,7 +2557,15 @@ def caution(g, update):
|
|
2068
2557
|
return _compilable_cautioning(g, update)
|
2069
2558
|
|
2070
2559
|
|
2071
|
-
def
|
2560
|
+
def _inner_precond_update_prob_schedule(
|
2561
|
+
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2562
|
+
):
|
2563
|
+
return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
|
2564
|
+
|
2565
|
+
|
2566
|
+
def precond_update_prob_schedule(
|
2567
|
+
max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2568
|
+
):
|
2072
2569
|
"""Anneal preconditioner update probability during beginning of training.
|
2073
2570
|
|
2074
2571
|
PSGD benefits from more preconditioner updates at the beginning of training,
|
@@ -2079,11 +2576,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
|
|
2079
2576
|
`min_prob` by ~4000 steps. Default settings work very well for most models and
|
2080
2577
|
training regimes.
|
2081
2578
|
"""
|
2082
|
-
|
2083
|
-
|
2084
|
-
|
2085
|
-
|
2086
|
-
return _schedule
|
2579
|
+
return functools.partial(
|
2580
|
+
_inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
|
2581
|
+
)
|
2087
2582
|
|
2088
2583
|
|
2089
2584
|
def merge_group(group, *tensors):
|
@@ -2217,3 +2712,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
|
2217
2712
|
def disable_caution_scaling():
|
2218
2713
|
global _compilable_cautioning
|
2219
2714
|
_compilable_cautioning = _compilable_caution_no_scale
|
2715
|
+
|
2716
|
+
|
2717
|
+
@decorator_knowngood
|
2718
|
+
def sam_step(parameters, ball_size, adaptive: bool = True):
|
2719
|
+
old_params = []
|
2720
|
+
for p in parameters:
|
2721
|
+
old_params.append(p.detach().clone())
|
2722
|
+
grad = promote(p.grad)
|
2723
|
+
if adaptive:
|
2724
|
+
grad = grad * promote(p).square()
|
2725
|
+
stochastic_add_(p.data, grad, ball_size)
|
2726
|
+
p.grad.zero_()
|
2727
|
+
return old_params
|