heavyball 1.6.3__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +515 -100
- heavyball/chainable.py +487 -156
- heavyball/optimizations/__init__.py +38 -0
- heavyball/optimizations/integrator.py +169 -0
- heavyball/optimizations/optimizations.py +329 -0
- heavyball/utils.py +780 -241
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/METADATA +3 -2
- heavyball-1.7.1.dist-info/RECORD +11 -0
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/WHEEL +1 -1
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info/licenses}/LICENSE +1 -1
- heavyball-1.6.3.dist-info/RECORD +0 -8
- {heavyball-1.6.3.dist-info → heavyball-1.7.1.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,11 +1,13 @@
+import contextlib
 import functools
 import gc
+import inspect
 import math
 import random
+import re
 import string
 import warnings
-from typing import List, Optional, Tuple,
-from unittest.mock import patch
+from typing import Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -15,19 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
 
-config.cache_size_limit = 2
-
-np.warnings = warnings
+config.cache_size_limit = 2**16
 
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode =
+zeroth_power_mode = "qr"  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
-
-
-
-
+_cudnn_double_backward_pattern = re.compile(
+    r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
+)
+_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
+_fd_error = (
+    "You can accelerate startup by globally enabling finite_differences first "  #
+    "(via opt.finite_differences=True or by subclassing it)\n"
+    "Original Error: "
+)
 
 
 def decorator(func):
@@ -35,7 +40,6 @@ def decorator(func):
 
     @functools.wraps(func)
     def _fn(*args, **kwargs):
-        disable = compile_mode_recommended_to_none is None
         if is_compiling() or compile_mode_recommended_to_none is None:
             return func(*args, **kwargs)
         nonlocal compiled
@@ -65,8 +69,17 @@ einsum_base = string.ascii_lowercase
 
 
 @decorator_knowngood
-def _compilable_schedule_free_(
-
+def _compilable_schedule_free_(
+    p: List[Tensor],
+    z: List[Tensor],
+    ckp1: Tensor,
+    update: List[Tensor],
+    lr: Tensor,
+    beta1: Tensor,
+    decay: float,
+    grad: List[Tensor],
+    caution,
+):
     for op, oz, u_, g_ in zip(p, z, update, grad):
         u_ = u_.view_as(op)
         p_, z_, u_ = map(promote, (op, oz, u_))
@@ -81,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
         copy_stochastic_(oz, z_)
 
 
-def schedule_free_(
-
-
+def schedule_free_(
+    lr: float,
+    weight_lr_power: float,
+    weight_sum: float,
+    beta1: float,
+    parameters: List[Tensor],
+    z: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    caution: bool = False,
+    r: float = 0.0,
+    step: int = 0,
+    decay: float = 0.0,
+):
     weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
@@ -156,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
 
 
 def beta_debias(beta, step):
-    return 1 - (1 - beta) / (1 - beta
+    return 1 - (1 - beta) / (1 - beta**step)
 
 
 def eps_sqrt(item, eps):
@@ -164,8 +188,9 @@ def eps_sqrt(item, eps):
 
 
 @decorator_knowngood
-def _compilable_exp_avg_sq_(
-
+def _compilable_exp_avg_sq_(
+    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
+):
     g32 = promote(grad)
     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
 
@@ -226,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
     copy_stochastic_list_(gradients, g32)
 
 
-def adaptive_gradient_clipping_(
-
+def adaptive_gradient_clipping_(
+    parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
+):
     if clip_val <= 0:
         return gradients
     parameters, gradients = list_guard(parameters, gradients)
@@ -253,23 +279,24 @@ def clean():
 
 
 def _ignore_warning(msg):
-    warnings.filterwarnings(
+    warnings.filterwarnings("ignore", f".*{msg}.*")
 
 
-def set_torch(benchmark_limit: int = 32):
+def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
     cudnn.benchmark = True
     cudnn.deterministic = False
     cudnn.benchmark_limit = benchmark_limit
     torch.use_deterministic_algorithms(False)
     torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
-    opt_einsum.
-    opt_einsum.strategy = "auto"
+    opt_einsum.set_flags(True, einsum_strategy)
 
     # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
     _ignore_warning(
-
+        "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
+    )
     _ignore_warning(
-
+        "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
+    )
 
 
 @decorator
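Note: `set_torch` now accepts an `einsum_strategy` argument that is forwarded to `torch.backends.opt_einsum.set_flags`. A minimal usage sketch, assuming the module is importable as `heavyball.utils`:

    from heavyball import utils

    # enable cudnn autotuning and let opt_einsum search for a contraction strategy
    utils.set_torch(benchmark_limit=32, einsum_strategy="auto")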
@@ -277,7 +304,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype)  # Preserve float64 if present
-    X /=
+    X /= X.norm() + eps  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
@@ -290,10 +317,10 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
 
 
 def ortho(x):
-    if zeroth_power_mode ==
+    if zeroth_power_mode == "qr":
         return torch.linalg.qr(x).Q
-    if zeroth_power_mode ==
-        u,
+    if zeroth_power_mode == "svd":
+        u, _s, v = torch.linalg.svd(x)
         return u @ v.T
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
@@ -351,12 +378,12 @@ def _compilable_grafting(magnitude, direction):
 
 @decorator_knowngood
 def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
-    if mode ==
+    if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode ==
+    elif mode == "qr":
         y = torch.linalg.qr(promote(x)).Q
-    elif mode ==
-        u,
+    elif mode == "svd":
+        u, _s, v = torch.linalg.svd(promote(x))
         y = u @ v.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -403,7 +430,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
         q_old = promote(q.data)
 
         tmp = m @ q_old
-        est_eig = torch.einsum(
+        est_eig = torch.einsum("ij,ij->j", q_old, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
 
         tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
@@ -415,19 +442,20 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
         return
 
     assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
-    in_str = einsum_base[:exp_avg.dim()]
-    out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
+    in_str = einsum_base[: exp_avg.dim()]
+    out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
 
     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
     if not from_shampoo:
         return
 
-    to_shampoo =
-    out_str =
+    to_shampoo = ",".join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
+    out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
 
-    subscripts = f
-    exp_avg_new = torch.einsum(
-
+    subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
+    exp_avg_new = torch.einsum(
+        subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
+    )
     copy_stochastic_(exp_avg, exp_avg_new)
 
     for q, q_new in zip(Q, new_qs):
@@ -453,11 +481,11 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
     while True:
         try:
             eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
-
+            _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
             eigvec = eigvec.to(device=device, dtype=dtype)
             break
         except torch.OutOfMemoryError:
-            if m.device.type ==
+            if m.device.type == "cpu":
                 raise
             else:
                 m = m.cpu()
@@ -489,21 +517,21 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
 
 def get_beta1(group):
     beta = None
-    if
-        beta = group[
-    if beta is None and
-        beta = group[
+    if "beta" in group:
+        beta = group["beta"]
+    if beta is None and "betas" in group:
+        beta = group["betas"][0]
     if beta is None:
         raise ValueError("Beta not found in group.")
     return beta
 
 
 def get_beta2(group):
-    if
+    if "palm" in group and group["palm"] is True and "beta2_scale" in group:
         step = max(group.get("step", 1), 1)
-        return 1 - step ** -group[
-    if
-        return group[
+        return 1 - step ** -group["beta2_scale"]
+    if "betas" in group:
+        return group["betas"][1]
     raise ValueError("Beta2 not found in group.")
 
 
@@ -554,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)
 
 
+@decorator_knowngood
+def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
+
+
+def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
+    x, y = list_guard(x, y)
+    alpha, divisor = scalar_guard(alpha, divisor, x[0])
+    _compilable_stochastic_add_divide_(x, y, alpha, divisor)
+
+
 @decorator_knowngood
 def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
     for x_, y_ in zip(x, y):
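Note: the new `stochastic_add_divide_` helper overwrites each `x_i` with `(x_i + alpha * y_i) / divisor`, writing the result back with stochastic rounding for low-precision tensors. A small sketch of the semantics, assuming the import path `heavyball.utils`:

    import torch
    from heavyball.utils import stochastic_add_divide_

    x = [torch.randn(4, dtype=torch.bfloat16)]
    y = [torch.randn(4, dtype=torch.bfloat16)]
    stochastic_add_divide_(x, y, alpha=0.5, divisor=2.0)  # in-place on x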
@@ -580,9 +622,9 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
         if not isinstance(m, Tensor):
             continue
         b = einsum_base[idx]
-        g0 = einsum_base[:grad.dim()]
+        g0 = einsum_base[: grad.dim()]
         g1 = g0.replace(b, b.upper())
-        outer_product = torch.einsum(f
+        outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
         stochastic_lerp_(m, outer_product, 1 - beta)
 
 
@@ -602,6 +644,20 @@ def promote(x):
     return x
 
 
+def promote_detach(x, should_promote):
+    if x is None:
+        return x
+    if should_promote:
+        x = promote(x)
+    return x.detach()
+
+
+def detach(x):
+    if isinstance(x, Tensor):
+        return x.detach()
+    return x
+
+
 def min_dtype(xs: List[Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
@@ -623,19 +679,19 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
     """
     Initializes the preconditioner matrices (L and R in the paper).
     """
-    state[
+    state["GG"] = []  # Will hold all the preconditioner matrices (L and R in the paper).
     if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
         for sh in grad.shape:
             if sh > max_precond_dim or sh == 1:
                 # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
-                state[
+                state["GG"].append(None)
             else:
-                state[
+                state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
     else:
-        state[
+        state["GG"].append(None)
 
-    update_ggt(grad, state[
-    state[
+    update_ggt(grad, state["GG"], max_precond_dim, precondition_1d, 0)
+    state["Q"] = get_orthogonal_matrix(state["GG"])
 
 
 @decorator
@@ -646,34 +702,45 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param = einsum_base[:grad.dim()]
-    preconditioners = ",".join([(g + g.upper())[
+    param = einsum_base[: grad.dim()]
+    preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
     if preconditioners:
-        out =
-        out = torch.einsum(f
+        out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
+        out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
         grad = out.to(grad.dtype)
     return grad
 
 
-
-
-
+@contextlib.contextmanager
+def patch_backward():
+    @contextlib.contextmanager
+    def _inner(module):
+        original = module.backward
 
-
-
+        signature = inspect.signature(original)
+
+        def patched_backward(*args, **kwargs):
+            new_kwargs = signature.bind(*args)
+            new_kwargs.apply_defaults()
+            new_kwargs = new_kwargs.arguments
+            new_kwargs.update(kwargs)
+            new_kwargs["create_graph"] = True
+            return original(**new_kwargs)
+
+        module.backward = patched_backward
+        yield
+        module.backward = original
+
+    with _inner(torch.Tensor), _inner(torch.autograd):
+        yield
 
-    Returns:
-        The return value of the modified closure.
-    """
 
-
-
-        return original_backward(self, *args, **kwargs)
+def hasattr_none(obj, name):
+    return getattr(obj, name, None) is not None
 
-    original_backward = torch.Tensor.backward
 
-
-
+class ExactHVPFailed(ValueError):
+    pass
 
 
 class StatefulOptimizer(torch.optim.Optimizer):
@@ -683,18 +750,22 @@ class StatefulOptimizer(torch.optim.Optimizer):
     The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
     Further notice that both methods have different numerics outputs
     """
+
     ema_decay: float = 0.001
     compile_step: bool = False
     hessian_approx: bool = False
     precond_schedule: Union[Callable, float, None] = None
     stochastic_schedule: bool = False
     finite_differences: bool = False
+    fallback_to_finite_differences: bool = True
+    _fallback_enabled: bool = False
+    hvp_interval: int = 1  # grad is faster initially, hvp later
 
     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
-        super().__init__(params, {**defaults,
+        super().__init__(params, {**defaults, "foreach": foreach})
         self.use_ema = use_ema
         self.mapping = {}
-        self._inner_group = {
+        self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
         self._precond_rng = random.Random(0x12312)
         self._is_preconditioning = None
 
@@ -710,34 +781,51 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
         for p, g in zip(p_list, g_list):
             state = self.state_(p)
-            if
-            state[
-        old_gs = [self.state_(p)[
+            if "mars_old_grad" not in state:
+                state["mars_old_grad"] = torch.zeros_like(g)
+        old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
         mars_correction(g_list, old_gs, mars_gamma, beta)
 
-    def split_p_and_g_in_group(
-
+    def split_p_and_g_in_group(
+        self,
+        group: dict,
+        skip_none: bool = True,
+        should_promote: bool = True,
+        beta1: float = -1.0,
+        raw: bool = False,
+    ):
         for p in group["params"]:
+            grad = getattr(p, "grad", None)
+            if grad is None and skip_none:
+                continue
+
+            p.grad = None
+
+            if raw:
+                yield p, grad
+                continue
+
             if p in self.mapping:
                 p_views = self.mapping[p]
            else:
                 self.mapping[p] = p_views = merge_group(group, p)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            vector = getattr(p, "vector", None)
+            hessian_vector = getattr(p, "hessian_vector", None)
+            p.vector = None
+            p.hessian_vector = None
+
+            grad, vs, hvs = [
+                [None] * len(p_views) if x is None else merge_group(group, x)  #
+                for x in (grad, vector, hessian_vector)
+            ]
+
+            for pv, g, v, hv in zip(p_views, grad, vs, hvs):
+                g = promote_detach(g, should_promote)
+                if beta1 >= 0 and group.get("mars", False):
+                    self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
+                pv.vector = promote_detach(v, should_promote)
+                pv.hessian_vector = promote_detach(hv, should_promote)
                 yield pv, g
 
     def state_size(self) -> int:
@@ -759,48 +847,108 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def ema_update(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]
 
                 if not active_p:
                     return
 
-                k = group[
+                k = group["ema_step"] = group.get("ema_step", -1) + 1
 
                 for p in active_p:
-                    if
-                    self.state_(p)[
+                    if "param_ema" not in self.state_(p):
+                        self.state_(p)["param_ema"] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
 
-                y, param_ema = zip(*[(p.data, self.state_(p)[
+                y, param_ema = zip(*[(p.data, self.state_(p)["param_ema"]) for p in active_p])
                 torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
 
     def copy_emas_to_params(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]
 
                 if not active_p:
                     return
 
                 for p in active_p:
-                    if
+                    if "param_ema" in self.state_(p):
                         p_clone = p.data.clone()
-                        set_(p.data, self.state_(p)[
-                        set_(self.state_(p)[
+                        set_(p.data, self.state_(p)["param_ema"])
+                        set_(self.state_(p)["param_ema"], p_clone)
 
     def copy_params_to_emas(self):
         with torch.no_grad():
             for group in self.param_groups:
-                active_p = [p for p in group[
+                active_p = [p for p in group["params"]]
 
                 if not active_p:
                     return
 
                 for p in active_p:
-                    if
-                        ema_clone = self.state_(p)[
-                        set_(self.state_(p)[
+                    if "param_ema" in self.state_(p):
+                        ema_clone = self.state_(p)["param_ema"].data.clone()
+                        set_(self.state_(p)["param_ema"], p.data)
                         set_(p.data, ema_clone)
 
+    def _finite_differences_hvp(self, closure):
+        with torch.enable_grad():
+            loss = closure()  # closure without retain_graph=True
+
+        grads = []
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                grads.append(g)
+                p.vector = torch.randn_like(p)
+                p.orig = p.data.clone()
+                # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
+                stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
+
+        with torch.enable_grad():
+            closure()
+
+        # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
+        # this costs more memory, but the imprecision seems too severe to use the other method
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                p.grad = grads.pop(0)
+                stochastic_add_(g, p.grad, -1)  # technically, we have to divide by the scale here
+                p.hessian_vector = g
+                p.data.copy_(p.orig)
+                del p.orig
+        return loss
+
+    def _double_backward_hvp(self, closure):
+        with torch.enable_grad(), patch_backward():
+            loss = closure()
+
+        params, grads = [], []
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                params.append(p)
+                grads.append(g)
+
+        if not params:
+            raise ValueError("No parameter has gradients")
+
+        vs = [torch.randn_like(p) for p in params]
+        with torch.enable_grad():
+            try:
+                hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
+            except RuntimeError as e:
+                raise ExactHVPFailed(str(e.args))
+
+        unused = []
+        for p, g, v, hv in zip(params, grads, vs, hvs):
+            p.hessian_vector = detach(hv)
+            p.grad = detach(g)
+            p.vector = detach(v)
+            if hv is None:
+                unused.append(list(p.shape))
+
+        if unused:
+            raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")
+
+        return loss
+
     def _handle_closure(self, closure):
         hessian_approx = self.hessian_approx and self._is_preconditioning
 
@@ -809,53 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
             raise ValueError("Hessian approximation requires a closure.")
             return None
 
-
+        step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
+        if not hessian_approx or step % self.hvp_interval == 0:
             with torch.enable_grad():
                 loss = closure()
             return loss
 
-        if self.finite_differences:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    p.grad = g
-        params, grads = zip(*[x for group in self.param_groups for x in
-                              self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
-        vs = [torch.randn_like(p) for p in params]
-        with torch.enable_grad():
-            hvs = torch.autograd.grad(grads, params, vs)
-
-        for p, g, v, hv in zip(params, grads, vs, hvs):
-            p.hessian_vector = hv
-            p.grad = g
-            p.vector = v
-
-        return loss
+        if self.finite_differences or self._fallback_enabled:
+            return self._finite_differences_hvp(closure)
+
+        try:
+            return self._double_backward_hvp(closure)
+        except NotImplementedError as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
+                raise
+            warn_once(
+                "CUDNN doesn't support double-backward for some models (including RNNs). "  #
+                f"Falling back to finite_differences.\n{_fd_error}{e}"
+            )
+        except RuntimeError as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
+                raise
+            warn_once(
+                f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
+                f"the model.\n{_fd_error}{e}"
+            )
+        except ExactHVPFailed as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
+        self._fallback_enabled = True
+        return self._handle_closure(closure)
 
     def step(self, closure: Optional[Callable] = None):
         if self.precond_schedule is None:
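Note: `_handle_closure` now attempts an exact double-backward HVP first and, when `fallback_to_finite_differences` is enabled, falls back to `_finite_differences_hvp` on the known CUDNN, torch.compile, and unused-parameter failures matched above. As the `_fd_error` message suggests, the fallback attempt can be skipped entirely by enabling finite differences up front; a hedged sketch, where the optimizer class, model, and batch are placeholders:

    opt = SomeHeavyBallOptimizer(model.parameters(), lr=1e-3)  # placeholder class name
    opt.finite_differences = True  # use finite-difference HVPs from the start

    def closure():
        loss = model(batch).square().mean()  # placeholder loss
        loss.backward()
        return loss

    opt.step(closure)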
@@ -867,11 +1003,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
         # we assume that parameters are constant and that there are no excessive recompiles
         with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
             for group in self.param_groups:
-                group[
+                group["is_preconditioning"] = self._is_preconditioning
                 self._step(group)
             if self.use_ema:
                 self.ema_update()
-
+            for real, views in self.mapping.items():
+                for tensor in (real, *views):
+                    for key in ("grad", "vector", "hessian_vector", "orig"):
+                        if hasattr(tensor, key):
+                            setattr(tensor, key, None)
         return loss
 
 
@@ -891,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
 
 
 @decorator_knowngood
-def _compilable_adam_(
-
+def _compilable_adam_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -903,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
     copy_stochastic_list_(grad, u32)
 
 
-def adam_(
-
+def adam_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -912,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
 
 
 @decorator_knowngood
-def _fused_compilable_adam_(
-
-
+def _fused_compilable_adam_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    decay: Tensor,
+    lr: Tensor,
+    eps: Tensor,
+    caution: bool,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -925,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
     _compilable_update_(y, u32, decay, lr, caution, g32)
 
 
-def fused_adam_(
-
-
+def fused_adam_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    lr: float,
+    eps: float,
+    decay: float,
+    caution: bool,
+):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(
-
+def _compilable_laprop_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -946,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
     copy_stochastic_list_(grad, gp32)
 
 
-def laprop_(
-
+def laprop_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -955,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
 
 
 @decorator_knowngood
-def _fused_compilable_laprop_(
-
-
+def _fused_compilable_laprop_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    lr: Tensor,
+    decay: Tensor,
+    caution: bool,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -968,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     _compilable_update_(y, u32, decay, lr, caution, gp32)
 
 
-def fused_laprop_(
-
-
+def fused_laprop_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    lr: float,
+    decay: float,
+    caution: bool,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -978,7 +1190,7 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
 
 @decorator_knowngood
 def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    u32, g32, exp_avg_sq32
+    u32, g32, exp_avg_sq32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq]]
     _compilable_update_(y, u32, decay, lr, caution, g32)
 
     beta1 = beta_debias(beta1, step)
@@ -997,7 +1209,7 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e
 
 @decorator_knowngood
 def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
-    g32,
+    g32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
     update = [e.clone() for e in exp_avg]
 
     beta1 = beta_debias(beta1, step)
@@ -1044,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
 
 
 @decorator_knowngood
-def _compilable_update_(
-
+def _compilable_update_(
+    p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
+):
     for u_, g_, p_ in zip(u, g, p):  # lr is data-dependent -> can't compile a foreach
         u_ = promote(u_.view_as(p_))
         p32_ = promote(p_)
@@ -1055,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
         copy_stochastic_(p_, p32_)
 
 
-def update_param_(
-
+def update_param_(
+    param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
+):
     param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
@@ -1064,38 +1278,117 @@ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: f
     _compilable_update_(param, update, decay, lr, caution, grad)
 
 
-def precond_schedule(step, precond_scheduler
+def precond_schedule(step, precond_scheduler):
     precond_prob = max(step, 1) ** precond_scheduler[0]
     precond_prob = math.log10(precond_prob)
     precond_prob = precond_prob ** precond_scheduler[1] + 1
-
-    update_precond = rng.random() < precond_prob
-    return update_precond
+    return 1 / precond_prob
 
 
 def get_soap_precond_schedule(precond_scheduler):
-
-
-    def _inner(step):
-        return precond_schedule(step, precond_scheduler, rng)
-
-    return _inner
+    return functools.partial(precond_schedule, precond_scheduler=precond_scheduler)
 
 
 def _max_idx(x: List[int]):
     return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
 
 
-
-
+@decorator_knowngood
+def stable_exp(x: Tensor):
+    # fp16:
+    # exp(x) is stable in [-17, 11]
+    # `stable_exp` extends to [-17, 17]
+    # average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
+    # fp32:
+    # exp(x) is stable in [-103, 88]
+    # `stable_exp` extends to [-103, 103]
+    # average error (in [-87, 87]) reduced from 3.309-06 to 3.224-06
+    return torch.where(x > 0, 1 / (-x).exp(), x.exp())
+
+
+@decorator_knowngood
+def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
+    # 1 / (mean(x ** pow) ** (1 / pow / 2))
+    log_x = x.double().abs().clamp(min=eps).log()
+    log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
+    return stable_exp(-log_mean_x_pow / pow / 2)
+
+
+@decorator_knowngood
+def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
+    # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
+    log_x = x.double().abs().clamp(min=eps).log()
+    log_y = y.double().abs().clamp(min=eps).log()
+
+    x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
+    x_normed = x_normed / pow0 / 2
+
+    y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
+    y_normed = y_normed / pow1 / 2
+
+    return stable_exp(x_normed - y_normed)
+
+
+def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
+    automatic_scale = True
+    manual_hint = " Set it manually using `precond_init_scale=0.1`"
+    if scale is not None:
+        automatic_scale = False
+        warn_once(
+            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
+        )
+        if scale_scale is not None and scale_scale != 1:
+            warn_once(
+                "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
+            )
+    elif hessian_vector is None:
+        scale = mean_root(grad, 4) * scale_scale
+    else:
+        scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
+    if isinstance(scale, torch.Tensor):
+        scale = scale.item()  # slow, but necessary
+    if np.isfinite(scale):
+        if scale > scale_max or scale < 1 / scale_max:
+            warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
+        return scale
+    if not automatic_scale:
+        raise ValueError("The manually set precond_init_scale is not finite")
+
+    for x in (grad, hessian_vector, vector):
+        if x is None:
+            continue
+        if torch.allclose(x, torch.zeros_like(x)).item():
+            raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
+        if not torch.isfinite(x).all().item():
+            raise ValueError("Grad or HVP is not finite")
+    raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
+
+
+def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
+    scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
+    U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
+    V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
+    d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
+    return U, V, d
+
+
+def init_Q_exprs(
+    grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
+):
+    """
+    For a scalar or tensor `grad`, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
+
+    precond init scale computation from
+    https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
     """
+    scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
     letters = string.ascii_lowercase + string.ascii_uppercase
-    dtype = dtype if dtype is not None else
-    shape =
+    dtype = dtype if dtype is not None else grad.dtype
+    shape = grad.shape
 
     if len(shape) == 0:  # scalar
-        Q = [scale * torch.ones_like(
+        Q = [scale * torch.ones_like(grad, dtype=dtype)]
         exprA = ",->"
         exprGs = [",->"]
         exprP = ",,->"
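Note: `mean_root` and `divided_root` compute their means in log space so that the intermediate `mean(|x| ** pow)` cannot overflow, and `stable_exp` keeps the final exponential in range by rewriting `exp(x)` as `1 / exp(-x)` for positive inputs. A small numeric check of the identity behind `mean_root`, using plain float64 torch and no heavyball imports:

    import math
    import torch

    x = torch.rand(1000, dtype=torch.float64) + 0.1
    p = 4.0
    direct = 1.0 / x.abs().pow(p).mean().pow(1.0 / (2.0 * p))
    log_domain = torch.exp(-((x.abs().log() * p).logsumexp(0) - math.log(x.numel())) / (2.0 * p))
    assert torch.allclose(direct, log_domain)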
@@ -1103,7 +1396,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
 
     # Tensor
     if len(shape) > 13:
-        raise ValueError(f"Got tensor with dim {len(
+        raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
 
     scale = scale ** (1 / len(shape))
 
@@ -1119,8 +1412,10 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
-        raise ValueError(
-
+        raise ValueError(
+            f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
+            "[None, 'one_diag', 'all_diag', 'smart_one_diag']"
+        )
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
@@ -1129,7 +1424,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
         if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
             # use diagonal matrix as preconditioner for this dim
-            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=
+            Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
 
             piece1A.append(letters[i])
             piece2A = piece2A + letters[i]
@@ -1143,13 +1438,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
             piece4P = piece4P + letters[i + 13]
         else:
             # use triangular matrix as preconditioner for this dim
-            Q.append(scale * torch.eye(size, dtype=dtype, device=
+            Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
             piece1A.append(letters[i] + letters[i + 13])
             piece2A = piece2A + letters[i + 13]
             piece3A = piece3A + letters[i]
             piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
             piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
-            subscripts =
+            subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
             exprGs.append(subscripts)
             a, b, c = (letters[i], letters[i + 13], letters[i + 26])
             piece1P.append(a + b)
@@ -1158,7 +1453,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
             piece4P = piece4P + b
 
     exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
-    exprP =
+    exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
     return [Q, (exprA, tuple(exprGs), exprP)]
 
 
@@ -1170,37 +1465,207 @@ def psgd_balance_Q(Q_in):
     torch._foreach_mul_(Q_in, list(norms))
 
 
-
-
-
-
-
-
-
-
-
+@decorator
+def psgd_balance_lra(U: Tensor, V: Tensor):
+    u_norm = promote(torch.linalg.vector_norm(U))
+    v_norm = promote(torch.linalg.vector_norm(V))
+    scale = (u_norm / v_norm) ** 0.5
+    U.div_(scale)
+    V.mul_(scale)
+
+
+@decorator
+def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
+    dtype = min_dtype([U, V, x])
+    return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
+
+
+def update_lra_precond_(
+    U: List[Tensor],
+    V: List[Tensor],
+    d: List[Tensor],
+    vector: Tensor,
+    hessian_vector: Tensor,
+    eps: float,
+    step: float,
+    delayed: bool,
+):
+    """
+    Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
+    """
+    U_orig, V_orig, d_orig = U, V, d
+
+    U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
+
+    dtype = min_dtype([U, V, vector, hessian_vector])
+    U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
+
+    eps = scalar_guard(eps, vector)
+
+    Qh = low_rank_mm(U, V, d * hessian_vector)
+    Ph = d * low_rank_mm(V, U, Qh)
+    rank = U.size(1)
+
+    VtU = torch.einsum("br,bn->rn", V, U)  # (rank, rank)
+    I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
+    IpVtU = I + VtU
+    invQtv = vector / d
+
+    # LU factorization to reuse computation
+    try:
+        LU, pivots = torch.linalg.lu_factor(IpVtU)
+    except RuntimeError:
+        # Error:
+        # U[2,2] is zero and using it on lu_solve would result in a division by zero.
+        # If you still want to perform the factorization, consider calling
+        # linalg.lu(A, pivot) or linalg.lu_factor_ex(A, pivot)
+        # ---
+        # So, we skip this step and reattempt on the next one
+        return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
+
+    invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
+    invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
+    invPv = invPv / d
+
+    nablaD = Ph * hessian_vector - vector * invPv
+    divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
+    divisor = divisor.add(eps).sqrt().max()
+    d_step = step / divisor
+
+    apply_flat_add(d_orig, d * nablaD, -d_step)
+
+    a, b = Qh, invQtv
+
+    precond_u = random.random() < 0.5  # update either U or V, not both at the same time
+    precond = V if precond_u else U
+    atV = torch.einsum("b,br->r", a, precond)  # o == one
+    btV = torch.einsum("b,br->r", b, precond)
+    atVVt = torch.einsum("r,br->b", atV, precond)
+    btVVt = torch.einsum("r,br->b", btV, precond)
+    precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
+    if precond_u:
+        a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
+        b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
     else:
-
-
+        a = a + torch.einsum("br,r->b", V, atV)
+        b = b + torch.einsum("br,r->b", V, btV)
+        a = torch.einsum("b,r->br", a, atV)
+        b = torch.einsum("b,r->br", b, btV)
+    apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
+
+    if not delayed:
+        stochastic_add_([d], [d * nablaD], -d_step)
+        stochastic_add_([U if precond_u else V], [b - a], precond_step)
+    return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
+
+
+def lra_precond(U, V, d, g):
+    """
+    As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
+    """
+    g = low_rank_mm(U, V, d * g)
+    return d * low_rank_mm(V, U, g)
+
+
+@decorator_knowngood
+def dampen_grad(g: Tensor, damp: float = 2**-13):
+    # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
+    v = torch.randn_like(g)
+    return v, g + damp * g.abs().mean() * v
+
+
+@decorator_knowngood
+def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
+    update = lra_precond(U, V, d, update)
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        copy_stochastic_(p, update[start : start + size].view_as(p))
+        start += size
+
+
+@decorator_knowngood
+def apply_flat_update(params: List[Tensor], update: Tensor):
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        copy_stochastic_(p, update[start : start + size].view_as(p))
+        start += size
+
+
+@decorator_knowngood
+def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
+    start = 0
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
+        start += size
+
+
+@decorator_knowngood
+def extract_from_flat_update(params: List[Tensor], update: Tensor):
+    start = 0
+    outputs = []
+    update = update.flatten()
+    for p in params:
+        size = p.numel()
+        outputs.append(update[start : start + size].view_as(p))
+        start += size
+    return outputs
+
+
+@decorator_knowngood
+def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
+    last_dim = x[0].shape[-remaining:] if remaining else []
+    return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
+
+
+@decorator_knowngood
+def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
+    vs = []
+    gs = []
+    for g_ in g:
+        v, g = dampen_grad(g_, damp)
+        vs.append(v)
+        gs.append(g)
+    return flatten(vs), flatten(gs)
+
+
+@decorator_knowngood
+def casted_einsum(expr: str, *args: Tensor) -> Tensor:
+    md = min_dtype(args)
+    return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
+
+
+def psgd_calc_A_and_conjB(exprA, G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
+    order = G.dim()
+    if order > 1:
+        conjB = conjB.view_as(G).permute(*range(1, order), 0)
+    conjB = conjB.to(promote(G.dtype))
+    A = casted_einsum(exprA, *Q, G)
     for i, q in enumerate(Q):
+        q = promote(q)
         if q.dim() <= 1:
             conjB /= q
         else:
-
-
+            solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
+            conjB = solved.reshape_as(conjB)
         if i < order - 1:
-            conjB =
+            conjB = conjB.transpose(i, -1)
     return A, conjB
 
 
 def psgd_lb(A, max_abs):
     A /= max_abs
-    a0 = torch.einsum(
+    a0 = torch.einsum("ij,ij->j", A, A)
     i = torch.argmax(a0)
     x = torch.index_select(A, 1, i).flatten().contiguous()
-    x = torch.einsum(
+    x = torch.einsum("i,ij->j", x, A)
     x /= x.norm()
-    x = torch.einsum(
+    x = torch.einsum("j,kj->k", x, A)
     x = x.norm()
     x *= max_abs
     return x
@@ -1217,7 +1682,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
         term2 = promote(torch.einsum(exprG, conjB, conjB))
         term1, term2 = term1 - term2, term1 + term2
         term1 *= precond_lr
-        norm = term2.norm(float(
+        norm = term2.norm(float("inf"))
         if q.dim() < 2:
             term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
         else:
@@ -1225,9 +1690,12 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
             term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
             term1 = torch.mm(term1, q.to(term1.dtype))
         if store_triu_as_line:
-            term1 = triu_to_line([term1])[0][1]
-
-
+            term1 = triu_to_line([term1])[0][1]  # Convert update to line format
+            # Apply update directly to the tensor part of the state tuple o[1]
+            stochastic_add_(o[1], term1, -1)
+        else:
+            # Apply update to the state tensor o
+            stochastic_add_(o, term1, -1)
 
 
 @decorator_knowngood
@@ -1245,7 +1713,7 @@ def l2_normalization_(x, clip_at: float = 1e-8):
     return _compilable_l2_clip_(x, clip_at)
 
 
-def l2_clip_(x, clip_at: float = 1.):
+def l2_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
     return _compilable_l2_clip_(x, clip_at)
 
@@ -1437,12 +1905,13 @@ def warn_once(msg):
     _warned.add(msg)
 
 
-def psgd_should_update(
-
-
+def psgd_should_update(
+    group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
+):
+    group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
     if not isinstance(prob, float):
-        prob = prob(group[f
-    if group[
+        prob = prob(group[f"{name}_prob_step"])
+    if group["stochastic_schedule"]:
        return rng.random() < prob
     cumulative_prob = group.get(name, 0)
     group[name] = cumulative_prob + prob
@@ -1450,8 +1919,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(
-
+def precond_grad_cached_(
+    expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
+):
     if caution:
         ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
@@ -1564,18 +2034,86 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
 
 
 def merge_group(group, *tensors):
-    if not group.get(
+    if not group.get("merge_dims", False):
        return tensors
     if isinstance(tensors[0], list):
         return [merge_group(group, *t) for t in tensors]
 
     out = []
     for t in tensors:
-        append_or_extend(
-
+        append_or_extend(
+            out,
+            dim_merger(
+                t,
+                group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
+                group.get("split", False),
+            ),
+        )
     return out
 
 
+@decorator_knowngood
+def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        next_d = d * (g * s).sum()
+        s = s + u * d
+        next_d = next_d / s.abs().sum()
+        next_d = torch.maximum(next_d, d)
+        copy_stochastic_(u_, u * d)
+        copy_stochastic_(d_, next_d)
+        copy_stochastic_(s_, s)
+
+
+def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    _compilable_d_adapt_(grads, update, state, delta)
+
+
+@decorator_knowngood
+def _compilable_lr_adapt_(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
+):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        lr_grad = d.sigmoid()
+        lr_grad = lr_grad * (1 - lr_grad)
+        lr_grad = lr_grad * (s * g).mean()
+        d = d - lr_grad * lr_lr
+        copy_stochastic_(d_, d)
+        copy_stochastic_(u_, u * d.sigmoid())
+        copy_stochastic_(s_, u)
+
+
+def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    lr_lr = scalar_guard(lr_lr, grads[0])
+    _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+
+
+@decorator_knowngood
+def _compilable_pointwise_lr_adapt_(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
+):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        lr_grad = d.sigmoid()
+        lr_grad = lr_grad * (1 - lr_grad)
+        lr_grad = lr_grad * s * g
+        d = d - lr_grad * lr_lr
+        copy_stochastic_(d_, d)
+        copy_stochastic_(u_, u * d.sigmoid())
+        copy_stochastic_(s_, u)
+
+
+def pointwise_lr_adaptation(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
+):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    lr_lr = scalar_guard(lr_lr, grads[0])
+    _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+
+
 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
     optimizers = {}
 
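Note: `hook_optimizer_into_model` (end of the previous hunk) and `fused_hook` (next hunk) attach the optimizer to gradient hooks, which is why the next hunk replaces `step` on a fused optimizer with a warning. A hedged usage sketch; `ForeachSOAP` is an assumed class name, and any heavyball optimizer constructor should fit the same pattern:

    import torch
    import torch.nn as nn
    from heavyball import ForeachSOAP  # assumed optimizer class
    from heavyball.utils import hook_optimizer_into_model

    model = nn.Linear(8, 8)
    hook_optimizer_into_model(model, ForeachSOAP, lr=1e-3)
    model(torch.randn(2, 8)).sum().backward()  # the hooks are expected to step the per-parameter optimizers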
@@ -1598,8 +2136,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
 
     o = optimizer(parameters, *args, **kwargs)
     step_fn = o.step
-    o.step = functools.partial(
-
+    o.step = functools.partial(
+        warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
+    )
 
     def _step(p: Tensor):
         seen_params.add(p)
|