heavyball 1.7.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +276 -37
- heavyball/chainable.py +419 -206
- heavyball/helpers.py +808 -0
- heavyball/utils.py +1062 -315
- heavyball-2.0.0.dist-info/METADATA +122 -0
- heavyball-2.0.0.dist-info/RECORD +9 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/WHEEL +1 -1
- heavyball-1.7.2.dist-info/METADATA +0 -939
- heavyball-1.7.2.dist-info/RECORD +0 -8
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.2.dist-info → heavyball-2.0.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,28 +1,29 @@
|
|
1
|
+
import collections
|
1
2
|
import contextlib
|
2
3
|
import functools
|
3
4
|
import gc
|
4
5
|
import inspect
|
5
6
|
import math
|
7
|
+
import pickle
|
6
8
|
import random
|
7
9
|
import re
|
8
10
|
import string
|
9
11
|
import warnings
|
10
|
-
from typing import Callable, List, Optional, Tuple, Union
|
12
|
+
from typing import Callable, List, Literal, Optional, Tuple, Union
|
11
13
|
|
12
14
|
import numpy as np
|
13
15
|
import torch
|
14
16
|
from torch import Tensor
|
15
|
-
from torch._dynamo import config
|
16
17
|
from torch._dynamo.exc import TorchDynamoException
|
17
18
|
from torch.backends import cudnn, opt_einsum
|
19
|
+
from torch.nn import functional as F
|
18
20
|
from torch.utils._pytree import tree_map
|
19
21
|
|
20
|
-
config.cache_size_limit = 2**16
|
21
|
-
|
22
22
|
compile_mode = "max-autotune-no-cudagraphs"
|
23
23
|
dynamic = False
|
24
24
|
compile_mode_recommended_to_none = None
|
25
|
-
zeroth_power_mode = "
|
25
|
+
zeroth_power_mode = "newtonschulz"
|
26
|
+
precise_zeroth_power_mode = "qr" # or svd
|
26
27
|
tiny_bf16 = torch.finfo(torch.bfloat16).tiny
|
27
28
|
_cudnn_double_backward_pattern = re.compile(
|
28
29
|
r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
|
@@ -68,6 +69,16 @@ def decorator_knowngood(func: Callable, fullgraph: bool = True):
|
|
68
69
|
einsum_base = string.ascii_lowercase
|
69
70
|
|
70
71
|
|
72
|
+
@decorator_knowngood
|
73
|
+
def compiled_einsum(expr, *args):
|
74
|
+
"""
|
75
|
+
this is necessary to avoid the slowdown introduced by uncompiled einsum
|
76
|
+
uncompiled einsum is twice as slow if we add three 1-sized dimensions
|
77
|
+
for more, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
|
78
|
+
"""
|
79
|
+
return torch.einsum(expr, *args)
|
80
|
+
|
81
|
+
|
71
82
|
@decorator_knowngood
|
72
83
|
def _compilable_schedule_free_(
|
73
84
|
p: List[Tensor],
|
@@ -122,6 +133,47 @@ def schedule_free_(
|
|
122
133
|
return weight_sum
|
123
134
|
|
124
135
|
|
136
|
+
@decorator_knowngood
|
137
|
+
def _compilable_msam(
|
138
|
+
lr: Tensor,
|
139
|
+
beta1: Tensor,
|
140
|
+
param: List[Tensor],
|
141
|
+
z: List[Tensor],
|
142
|
+
update: List[Tensor],
|
143
|
+
grad: List[Tensor],
|
144
|
+
exp_avg: List[Tensor],
|
145
|
+
caution: bool,
|
146
|
+
decay: Tensor,
|
147
|
+
sam_step_size: Tensor,
|
148
|
+
):
|
149
|
+
exp_avg32 = _lerp(exp_avg, update, beta1)
|
150
|
+
for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
|
151
|
+
u_ = u_.view_as(z_)
|
152
|
+
z32_ = promote(z_)
|
153
|
+
if caution:
|
154
|
+
u_ = _compilable_cautioning(promote(g_), u_)
|
155
|
+
z32_ = z32_ * (1 - decay * lr) + u_ * -lr
|
156
|
+
copy_stochastic_(z_, z32_)
|
157
|
+
copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
|
158
|
+
|
159
|
+
|
160
|
+
def msam_(
|
161
|
+
lr: float,
|
162
|
+
beta1: float,
|
163
|
+
param: List[Tensor],
|
164
|
+
z: List[Tensor],
|
165
|
+
update: List[Tensor],
|
166
|
+
grad: List[Tensor],
|
167
|
+
exp_avg: List[Tensor],
|
168
|
+
caution: bool,
|
169
|
+
weight_decay: float,
|
170
|
+
sam_step_size: float,
|
171
|
+
):
|
172
|
+
param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
|
173
|
+
lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
|
174
|
+
_compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
|
175
|
+
|
176
|
+
|
125
177
|
def append_or_extend(base, new):
|
126
178
|
if isinstance(new, list):
|
127
179
|
base.extend(new)
|
@@ -161,7 +213,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
161
213
|
new_shape = [grad.shape[0], *new_shape[::-1]]
|
162
214
|
new_grad = grad.reshape(new_shape)
|
163
215
|
if not split:
|
164
|
-
return new_grad
|
216
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
165
217
|
|
166
218
|
grads = [new_grad]
|
167
219
|
for i, sh in reversed(list(enumerate(new_shape[:]))):
|
@@ -172,7 +224,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
172
224
|
continue
|
173
225
|
grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
|
174
226
|
if len(grads) == 1:
|
175
|
-
return new_grad
|
227
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
176
228
|
new_grads = []
|
177
229
|
for g in grads:
|
178
230
|
append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
|
@@ -189,14 +241,14 @@ def eps_sqrt(item, eps):
|
|
189
241
|
|
190
242
|
@decorator_knowngood
|
191
243
|
def _compilable_exp_avg_sq_(
|
192
|
-
state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[
|
244
|
+
state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
|
193
245
|
):
|
194
246
|
g32 = promote(grad)
|
195
247
|
s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
|
196
248
|
|
197
249
|
denom = [eps_sqrt(d, eps) for d in s32]
|
198
250
|
|
199
|
-
if out[0] is None:
|
251
|
+
if out is None or out[0] is None:
|
200
252
|
return denom
|
201
253
|
|
202
254
|
copy_stochastic_list_(out, denom)
|
@@ -265,8 +317,8 @@ def adaptive_gradient_clipping_(
|
|
265
317
|
def is_compiling():
|
266
318
|
try:
|
267
319
|
return torch.compiler.is_compiling()
|
268
|
-
except TorchDynamoException:
|
269
|
-
return
|
320
|
+
except (TorchDynamoException, AttributeError):
|
321
|
+
return False
|
270
322
|
|
271
323
|
|
272
324
|
def set_(dst: Tensor, src: Tensor):
|
@@ -279,16 +331,29 @@ def clean():
|
|
279
331
|
|
280
332
|
|
281
333
|
def _ignore_warning(msg):
|
282
|
-
warnings.filterwarnings("ignore", f".*{msg}.*")
|
334
|
+
warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
|
283
335
|
|
284
336
|
|
285
|
-
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
337
|
+
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
|
338
|
+
import opt_einsum as _opt_einsum
|
339
|
+
|
286
340
|
cudnn.benchmark = True
|
287
341
|
cudnn.deterministic = False
|
288
342
|
cudnn.benchmark_limit = benchmark_limit
|
289
343
|
torch.use_deterministic_algorithms(False)
|
290
344
|
torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
|
291
|
-
opt_einsum.set_flags(True
|
345
|
+
opt_einsum.set_flags(True)
|
346
|
+
if einsum_strategy == "heavyball":
|
347
|
+
opt_einsum.strategy = "auto-hq"
|
348
|
+
choices = _opt_einsum.paths._AUTO_HQ_CHOICES
|
349
|
+
for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
|
350
|
+
if isinstance(fn, int):
|
351
|
+
fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
|
352
|
+
for i in range(max(choices.keys()), max_val):
|
353
|
+
if i not in choices:
|
354
|
+
choices[i] = fn
|
355
|
+
else:
|
356
|
+
opt_einsum.strategy = einsum_strategy
|
292
357
|
|
293
358
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
294
359
|
_ignore_warning(
|
@@ -297,32 +362,39 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
|
297
362
|
_ignore_warning(
|
298
363
|
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
299
364
|
)
|
365
|
+
_ignore_warning(
|
366
|
+
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
|
367
|
+
)
|
300
368
|
|
301
369
|
|
302
|
-
@
|
370
|
+
@decorator_knowngood
|
303
371
|
def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
304
|
-
assert
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
if G.
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
372
|
+
assert (
|
373
|
+
G.ndim >= 2
|
374
|
+
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
375
|
+
assert steps == 5
|
376
|
+
X = G if G.dtype == torch.float64 else stochastic_round_(G)
|
377
|
+
if G.size(-2) > G.size(-1):
|
378
|
+
X = X.mT
|
379
|
+
|
380
|
+
stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps) # ensure top singular value <= 1
|
381
|
+
# Perform the NS iterations
|
382
|
+
for a, b, c in [
|
383
|
+
(4.0848, -6.8946, 2.9270),
|
384
|
+
(3.9505, -6.3029, 2.6377),
|
385
|
+
(3.7418, -5.5913, 2.3037),
|
386
|
+
(2.8769, -3.1427, 1.2046),
|
387
|
+
(2.8366, -3.0525, 1.2012),
|
388
|
+
]:
|
389
|
+
A = X @ X.mT
|
390
|
+
B = (
|
391
|
+
b * A + c * A @ A
|
392
|
+
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
313
393
|
X = a * X + B @ X
|
314
|
-
if G.size(0) > G.size(1):
|
315
|
-
X = X.T
|
316
|
-
return X.to(G.dtype)
|
317
|
-
|
318
394
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
if zeroth_power_mode == "svd":
|
323
|
-
u, _s, v = torch.linalg.svd(x)
|
324
|
-
return u @ v.T
|
325
|
-
raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
|
395
|
+
if G.size(-2) > G.size(-1):
|
396
|
+
X = X.mT
|
397
|
+
return X.to(G.dtype)
|
326
398
|
|
327
399
|
|
328
400
|
@decorator_knowngood
|
@@ -377,7 +449,7 @@ def _compilable_grafting(magnitude, direction):
|
|
377
449
|
|
378
450
|
|
379
451
|
@decorator_knowngood
|
380
|
-
def
|
452
|
+
def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
|
381
453
|
if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
|
382
454
|
y = zeropower_via_newtonschulz5(x, 5)
|
383
455
|
elif mode == "qr":
|
@@ -395,9 +467,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
|
|
395
467
|
y = _compilable_grafting(x, y)
|
396
468
|
else:
|
397
469
|
raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
|
470
|
+
if out is None:
|
471
|
+
return y
|
472
|
+
|
398
473
|
set_(out, y)
|
399
474
|
|
400
475
|
|
476
|
+
def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
|
477
|
+
return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
|
478
|
+
|
479
|
+
|
401
480
|
@decorator_knowngood
|
402
481
|
def _compilable_scatter_set(target, source, index):
|
403
482
|
target[:] = source.contiguous()[index].reshape_as(target)
|
@@ -413,6 +492,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
413
492
|
:param Q: List of current eigenbases (updated in-place to Q_new).
|
414
493
|
:param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
|
415
494
|
"""
|
495
|
+
if exp_avg.dim() == 0: # preconditioning doesn't make sense here
|
496
|
+
Q.clear()
|
497
|
+
return
|
498
|
+
|
416
499
|
if isinstance(Q, list) and not Q:
|
417
500
|
return
|
418
501
|
|
@@ -430,10 +513,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
430
513
|
q_old = promote(q.data)
|
431
514
|
|
432
515
|
tmp = m @ q_old
|
433
|
-
est_eig =
|
516
|
+
est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
|
434
517
|
sort_idx = torch.argsort(est_eig, descending=True)
|
435
518
|
|
436
|
-
tmp[:, sort_idx]
|
519
|
+
tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
|
437
520
|
new_qs.append(tmp)
|
438
521
|
|
439
522
|
if exp_avg is None:
|
@@ -453,7 +536,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
453
536
|
out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
454
537
|
|
455
538
|
subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
|
456
|
-
exp_avg_new =
|
539
|
+
exp_avg_new = compiled_einsum(
|
457
540
|
subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
|
458
541
|
)
|
459
542
|
copy_stochastic_(exp_avg, exp_avg_new)
|
@@ -487,10 +570,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
|
|
487
570
|
except torch.OutOfMemoryError:
|
488
571
|
if m.device.type == "cpu":
|
489
572
|
raise
|
490
|
-
|
573
|
+
if torch.cuda.is_available():
|
574
|
+
torch.cuda.synchronize(m.device)
|
575
|
+
clean()
|
576
|
+
m = m.cpu()
|
577
|
+
except RuntimeError as e:
|
578
|
+
if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
|
579
|
+
torch.cuda.synchronize(m.device)
|
580
|
+
clean()
|
491
581
|
m = m.cpu()
|
492
|
-
|
493
|
-
if m.dtype != torch.double:
|
582
|
+
elif m.dtype != torch.double:
|
494
583
|
m = m.double()
|
495
584
|
elif eps < max_eps:
|
496
585
|
eps = eps ** (2 / 3)
|
@@ -568,6 +657,19 @@ def scalar_guard(*args):
|
|
568
657
|
return out
|
569
658
|
|
570
659
|
|
660
|
+
def broadcastable_list_guard(*xs):
|
661
|
+
xs = list_guard(*xs)
|
662
|
+
for x in xs:
|
663
|
+
if isinstance(x[0], Tensor):
|
664
|
+
ref = x[0]
|
665
|
+
break
|
666
|
+
else:
|
667
|
+
raise ValueError("No tensor-valued input given")
|
668
|
+
xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
|
669
|
+
max_len = max(len(x) for x in xs)
|
670
|
+
return [x if len(x) > 1 else x * max_len for x in xs]
|
671
|
+
|
672
|
+
|
571
673
|
@decorator_knowngood
|
572
674
|
def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
|
573
675
|
for x_, y_ in zip(x, y):
|
@@ -576,8 +678,8 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
|
|
576
678
|
copy_stochastic_(x_, x32 + y32 * alpha)
|
577
679
|
|
578
680
|
|
579
|
-
def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
|
580
|
-
x, y =
|
681
|
+
def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
|
682
|
+
x, y = broadcastable_list_guard(x, y)
|
581
683
|
alpha = scalar_guard(alpha, x[0])
|
582
684
|
_compilable_stochastic_add_(x, y, alpha)
|
583
685
|
|
@@ -590,8 +692,10 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
|
|
590
692
|
copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
|
591
693
|
|
592
694
|
|
593
|
-
def stochastic_add_divide_(
|
594
|
-
x, y =
|
695
|
+
def stochastic_add_divide_(
|
696
|
+
x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
|
697
|
+
):
|
698
|
+
x, y = broadcastable_list_guard(x, y)
|
595
699
|
alpha, divisor = scalar_guard(alpha, divisor, x[0])
|
596
700
|
_compilable_stochastic_add_divide_(x, y, alpha, divisor)
|
597
701
|
|
@@ -604,8 +708,8 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
|
604
708
|
copy_stochastic_(x_, x32 * y32)
|
605
709
|
|
606
710
|
|
607
|
-
def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
608
|
-
x, y =
|
711
|
+
def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
|
712
|
+
x, y = broadcastable_list_guard(x, y)
|
609
713
|
_compilable_stochastic_multiply_(x, y)
|
610
714
|
|
611
715
|
|
@@ -624,7 +728,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
624
728
|
b = einsum_base[idx]
|
625
729
|
g0 = einsum_base[: grad.dim()]
|
626
730
|
g1 = g0.replace(b, b.upper())
|
627
|
-
outer_product =
|
731
|
+
outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
|
628
732
|
stochastic_lerp_(m, outer_product, 1 - beta)
|
629
733
|
|
630
734
|
|
@@ -706,7 +810,7 @@ def project(grad, Q, back: bool):
|
|
706
810
|
preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
|
707
811
|
if preconditioners:
|
708
812
|
out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
|
709
|
-
out =
|
813
|
+
out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
|
710
814
|
grad = out.to(grad.dtype)
|
711
815
|
return grad
|
712
816
|
|
@@ -714,24 +818,28 @@ def project(grad, Q, back: bool):
|
|
714
818
|
@contextlib.contextmanager
|
715
819
|
def patch_backward():
|
716
820
|
@contextlib.contextmanager
|
717
|
-
def
|
821
|
+
def patch_module(module):
|
718
822
|
original = module.backward
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
823
|
+
try:
|
824
|
+
signature = inspect.signature(original)
|
825
|
+
|
826
|
+
@functools.wraps(original)
|
827
|
+
def patched_backward(*args, **kwargs):
|
828
|
+
new_kwargs = signature.bind(*args)
|
829
|
+
new_kwargs.apply_defaults()
|
830
|
+
new_kwargs = new_kwargs.arguments
|
831
|
+
new_kwargs.update(kwargs)
|
832
|
+
new_kwargs["create_graph"] = True
|
833
|
+
return original(**new_kwargs)
|
834
|
+
|
835
|
+
module.backward = patched_backward
|
836
|
+
yield
|
837
|
+
finally:
|
838
|
+
module.backward = original
|
839
|
+
|
840
|
+
with contextlib.ExitStack() as stack:
|
841
|
+
stack.enter_context(patch_module(torch.Tensor))
|
842
|
+
stack.enter_context(patch_module(torch.autograd))
|
735
843
|
yield
|
736
844
|
|
737
845
|
|
@@ -743,6 +851,13 @@ class ExactHVPFailed(ValueError):
|
|
743
851
|
pass
|
744
852
|
|
745
853
|
|
854
|
+
use_default = object()
|
855
|
+
|
856
|
+
|
857
|
+
def _tensor_key(x: Tensor):
|
858
|
+
return x.data_ptr(), x.numel(), x.dtype, x.device
|
859
|
+
|
860
|
+
|
746
861
|
class StatefulOptimizer(torch.optim.Optimizer):
|
747
862
|
"""
|
748
863
|
finite_differences saves memory, but needs more compute. (Alternative is true HVP)
|
@@ -755,7 +870,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
755
870
|
compile_step: bool = False
|
756
871
|
hessian_approx: bool = False
|
757
872
|
precond_schedule: Union[Callable, float, None] = None
|
758
|
-
stochastic_schedule: bool =
|
873
|
+
stochastic_schedule: bool | Literal[use_default] = use_default
|
759
874
|
finite_differences: bool = False
|
760
875
|
fallback_to_finite_differences: bool = True
|
761
876
|
_fallback_enabled: bool = False
|
@@ -765,18 +880,62 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
765
880
|
super().__init__(params, {**defaults, "foreach": foreach})
|
766
881
|
self.use_ema = use_ema
|
767
882
|
self.mapping = {}
|
768
|
-
self.
|
769
|
-
|
883
|
+
self.mapping_inverse = {}
|
884
|
+
|
885
|
+
if self.stochastic_schedule is use_default:
|
886
|
+
stochastic_schedule = None
|
887
|
+
for group in self.param_groups:
|
888
|
+
new = group.get("stochastic_schedule", stochastic_schedule)
|
889
|
+
if stochastic_schedule is not None and new != stochastic_schedule:
|
890
|
+
raise ValueError("All parameter groups must have the same stochastic_schedule.")
|
891
|
+
stochastic_schedule = new
|
892
|
+
self.stochastic_schedule = stochastic_schedule
|
893
|
+
|
894
|
+
self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
|
895
|
+
self.precond_rng = random.Random(0x12312)
|
770
896
|
self._is_preconditioning = None
|
771
897
|
|
772
898
|
if self.hessian_approx and self.compile_step:
|
773
899
|
raise ValueError("Hessian approximation can't be used with compile_step.")
|
774
900
|
|
901
|
+
self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
|
902
|
+
self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
|
903
|
+
|
904
|
+
def _store_stats(self, state_dict: dict[str, any]):
|
905
|
+
state_dict["heavyball"] = {
|
906
|
+
"inner_group": self.inner_group,
|
907
|
+
"precond_rng": pickle.dumps(self.precond_rng),
|
908
|
+
"use_ema": self.use_ema,
|
909
|
+
"ema_decay": self.ema_decay,
|
910
|
+
"compile_step": self.compile_step,
|
911
|
+
"hessian_approx": self.hessian_approx,
|
912
|
+
"precond_schedule": pickle.dumps(self.precond_schedule),
|
913
|
+
"stochastic_schedule": self.stochastic_schedule,
|
914
|
+
"fallback_to_finite_differences": self.fallback_to_finite_differences,
|
915
|
+
"_fallback_enabled": self._fallback_enabled,
|
916
|
+
"hvp_interval": self.hvp_interval,
|
917
|
+
}
|
918
|
+
|
919
|
+
def _load_stats(self, state_dict):
|
920
|
+
sd = state_dict.pop("heavyball", {})
|
921
|
+
for k, v in sd.items():
|
922
|
+
if k in ("precond_rng", "precond_schedule"):
|
923
|
+
v = pickle.loads(v)
|
924
|
+
setattr(self, k, v)
|
925
|
+
|
775
926
|
def get_groups(self, group):
|
776
927
|
return [group]
|
777
928
|
|
778
|
-
|
779
|
-
|
929
|
+
@functools.lru_cache(maxsize=None)
|
930
|
+
def state_(self, arg: Tensor, fail: bool = True):
|
931
|
+
if not fail and arg not in self.mapping:
|
932
|
+
return {}
|
933
|
+
if _tensor_key(arg) not in self.mapping_inverse:
|
934
|
+
self._init_mapping()
|
935
|
+
state_param, index = self.mapping_inverse[_tensor_key(arg)]
|
936
|
+
if state_param not in self.state:
|
937
|
+
self.state[state_param] = collections.defaultdict(dict)
|
938
|
+
return self.state[state_param][index]
|
780
939
|
|
781
940
|
def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
|
782
941
|
for p, g in zip(p_list, g_list):
|
@@ -786,6 +945,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
786
945
|
old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
|
787
946
|
mars_correction(g_list, old_gs, mars_gamma, beta)
|
788
947
|
|
948
|
+
def _init_mapping(self, group: dict | None = None):
|
949
|
+
if group is None:
|
950
|
+
for group in self.param_groups:
|
951
|
+
self._init_mapping(group)
|
952
|
+
return
|
953
|
+
|
954
|
+
for p in group["params"]:
|
955
|
+
if p not in self.mapping:
|
956
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
957
|
+
for i, pv in enumerate(p_views):
|
958
|
+
self.mapping_inverse[_tensor_key(pv)] = (p, i)
|
959
|
+
|
789
960
|
def split_p_and_g_in_group(
|
790
961
|
self,
|
791
962
|
group: dict,
|
@@ -805,10 +976,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
805
976
|
yield p, grad
|
806
977
|
continue
|
807
978
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
self.mapping[p] = p_views = merge_group(group, p)
|
979
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
980
|
+
for i, pv in enumerate(p_views):
|
981
|
+
self.mapping_inverse[_tensor_key(pv)] = (p, i)
|
812
982
|
|
813
983
|
vector = getattr(p, "vector", None)
|
814
984
|
hessian_vector = getattr(p, "hessian_vector", None)
|
@@ -957,8 +1127,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
957
1127
|
raise ValueError("Hessian approximation requires a closure.")
|
958
1128
|
return None
|
959
1129
|
|
960
|
-
step = self.
|
961
|
-
if not hessian_approx or step % self.hvp_interval == 0:
|
1130
|
+
step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
|
1131
|
+
if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
|
962
1132
|
with torch.enable_grad():
|
963
1133
|
loss = closure()
|
964
1134
|
return loss
|
@@ -997,12 +1167,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
997
1167
|
if self.precond_schedule is None:
|
998
1168
|
self._is_preconditioning = False
|
999
1169
|
else:
|
1000
|
-
self._is_preconditioning = psgd_should_update(self.
|
1170
|
+
self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
|
1001
1171
|
loss = self._handle_closure(closure)
|
1002
1172
|
|
1003
1173
|
# we assume that parameters are constant and that there are no excessive recompiles
|
1004
1174
|
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
|
1005
1175
|
for group in self.param_groups:
|
1176
|
+
if "param_count" not in group:
|
1177
|
+
group["param_count"] = sum(p.numel() for p in group["params"])
|
1006
1178
|
group["is_preconditioning"] = self._is_preconditioning
|
1007
1179
|
self._step(group)
|
1008
1180
|
if self.use_ema:
|
@@ -1105,7 +1277,7 @@ def fused_adam_(
|
|
1105
1277
|
caution: bool,
|
1106
1278
|
):
|
1107
1279
|
y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
|
1108
|
-
beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
|
1280
|
+
beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
|
1109
1281
|
_fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
|
1110
1282
|
|
1111
1283
|
|
@@ -1184,7 +1356,7 @@ def fused_laprop_(
|
|
1184
1356
|
eps: float = 1e-8,
|
1185
1357
|
):
|
1186
1358
|
exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
|
1187
|
-
beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
|
1359
|
+
beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
|
1188
1360
|
_fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
|
1189
1361
|
|
1190
1362
|
|
@@ -1203,7 +1375,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
|
|
1203
1375
|
|
1204
1376
|
def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
|
1205
1377
|
exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
|
1206
|
-
beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
|
1378
|
+
beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
|
1207
1379
|
_fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
|
1208
1380
|
|
1209
1381
|
|
@@ -1233,11 +1405,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
|
|
1233
1405
|
|
1234
1406
|
|
1235
1407
|
@decorator_knowngood
|
1236
|
-
def stochastic_round_(ref: Tensor, source: Tensor):
|
1237
|
-
if source
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1408
|
+
def stochastic_round_(ref: Tensor, source: Tensor | None = None):
|
1409
|
+
if source is not None:
|
1410
|
+
if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
|
1411
|
+
return source
|
1412
|
+
if ref.dtype != torch.bfloat16:
|
1413
|
+
return source.to(ref.dtype)
|
1414
|
+
else:
|
1415
|
+
source = ref
|
1416
|
+
source = source.float()
|
1241
1417
|
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
1242
1418
|
result.add_(source.view(dtype=torch.int32))
|
1243
1419
|
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
@@ -1306,83 +1482,115 @@ def stable_exp(x: Tensor):
|
|
1306
1482
|
return torch.where(x > 0, 1 / (-x).exp(), x.exp())
|
1307
1483
|
|
1308
1484
|
|
1485
|
+
def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
|
1486
|
+
# ln(mean(x ** pow) ** (1 / pow / 2))
|
1487
|
+
normalization = math.log(x.numel())
|
1488
|
+
x = x.double()
|
1489
|
+
x = x.abs()
|
1490
|
+
x = x.clamp(min=eps)
|
1491
|
+
x = x.log()
|
1492
|
+
x = x * pow
|
1493
|
+
x = x.flatten()
|
1494
|
+
x = x.logsumexp(dim=0) # log(sum(exp( log(x) * P ) - more stable than sum(x ** P)
|
1495
|
+
x = x - normalization # sum -> mean (divide by x.numel() in log space)
|
1496
|
+
return x / pow / 2
|
1497
|
+
|
1498
|
+
|
1309
1499
|
@decorator_knowngood
|
1310
1500
|
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
|
1311
1501
|
# 1 / (mean(x ** pow) ** (1 / pow / 2))
|
1312
|
-
|
1313
|
-
log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
|
1314
|
-
return stable_exp(-log_mean_x_pow / pow / 2)
|
1502
|
+
return stable_exp(-_lse_mean(x, pow, eps))
|
1315
1503
|
|
1316
1504
|
|
1317
1505
|
@decorator_knowngood
|
1318
1506
|
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
|
1319
1507
|
# mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
|
1320
|
-
|
1321
|
-
log_y = y.double().abs().clamp(min=eps).log()
|
1322
|
-
|
1323
|
-
x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
|
1324
|
-
x_normed = x_normed / pow0 / 2
|
1508
|
+
return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
|
1325
1509
|
|
1326
|
-
y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
|
1327
|
-
y_normed = y_normed / pow1 / 2
|
1328
1510
|
|
1329
|
-
|
1511
|
+
class PrecondInitError(ValueError):
|
1512
|
+
pass
|
1330
1513
|
|
1331
1514
|
|
1332
|
-
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float =
|
1515
|
+
def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
|
1333
1516
|
automatic_scale = True
|
1334
1517
|
manual_hint = " Set it manually using `precond_init_scale=0.1`"
|
1518
|
+
scale_scale = 1 if scale_scale is None else scale_scale
|
1335
1519
|
|
1336
1520
|
if scale is not None:
|
1337
1521
|
automatic_scale = False
|
1338
1522
|
warn_once(
|
1339
1523
|
"It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
|
1340
1524
|
)
|
1341
|
-
if scale_scale
|
1525
|
+
if scale_scale != 1:
|
1526
|
+
warn_once(
|
1527
|
+
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
|
1528
|
+
)
|
1529
|
+
if scale_power is not None:
|
1342
1530
|
warn_once(
|
1343
|
-
"
|
1531
|
+
"precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
|
1344
1532
|
)
|
1345
1533
|
elif hessian_vector is None:
|
1346
1534
|
scale = mean_root(grad, 4) * scale_scale
|
1347
1535
|
else:
|
1348
1536
|
scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
|
1349
1537
|
|
1538
|
+
if automatic_scale:
|
1539
|
+
scale_power = 0.5 if scale_power is None else scale_power
|
1540
|
+
scale = scale**scale_power
|
1541
|
+
|
1350
1542
|
if isinstance(scale, torch.Tensor):
|
1351
1543
|
scale = scale.item() # slow, but necessary
|
1352
1544
|
|
1353
1545
|
if np.isfinite(scale):
|
1354
|
-
if scale > scale_max
|
1546
|
+
if scale > scale_max: # fallthrough to later checks
|
1355
1547
|
warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
|
1356
1548
|
else:
|
1357
1549
|
return scale
|
1358
1550
|
|
1359
1551
|
if not automatic_scale:
|
1360
|
-
raise
|
1552
|
+
raise PrecondInitError("The manually set precond_init_scale is not finite")
|
1361
1553
|
|
1362
1554
|
for x in (grad, hessian_vector, vector):
|
1363
1555
|
if x is None:
|
1364
1556
|
continue
|
1365
|
-
if torch.allclose(x, torch.zeros_like(x))
|
1366
|
-
raise
|
1557
|
+
if torch.allclose(x, torch.zeros_like(x)):
|
1558
|
+
raise PrecondInitError(
|
1559
|
+
f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
|
1560
|
+
)
|
1367
1561
|
if not torch.isfinite(x).all().item():
|
1368
|
-
raise
|
1562
|
+
raise PrecondInitError("Grad or HVP is not finite")
|
1369
1563
|
|
1370
1564
|
if np.isfinite(scale):
|
1371
1565
|
return scale
|
1372
1566
|
|
1373
|
-
raise
|
1567
|
+
raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
|
1374
1568
|
|
1375
1569
|
|
1376
|
-
def init_lra(
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1570
|
+
def init_lra(
|
1571
|
+
grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
|
1572
|
+
):
|
1573
|
+
# "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1" from @lixilinx at
|
1574
|
+
# https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
|
1575
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1576
|
+
uv_scale = (param_count * (rank + eps)) ** -0.5
|
1577
|
+
U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
|
1578
|
+
V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device) * uv_scale
|
1380
1579
|
d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
|
1381
1580
|
return U, V, d
|
1382
1581
|
|
1383
1582
|
|
1384
1583
|
def init_Q_exprs(
|
1385
|
-
grad,
|
1584
|
+
grad,
|
1585
|
+
scale,
|
1586
|
+
scale_scale,
|
1587
|
+
scale_power,
|
1588
|
+
max_size,
|
1589
|
+
min_ndim_triangular,
|
1590
|
+
memory_save_mode,
|
1591
|
+
hessian_vector,
|
1592
|
+
vector,
|
1593
|
+
dtype=None,
|
1386
1594
|
):
|
1387
1595
|
"""
|
1388
1596
|
For a scalar or tensor `grad`, we initialize its preconditioner Q and
|
@@ -1391,21 +1599,13 @@ def init_Q_exprs(
|
|
1391
1599
|
precond init scale computation from
|
1392
1600
|
https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
|
1393
1601
|
"""
|
1394
|
-
scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
|
1395
|
-
letters = string.ascii_lowercase + string.ascii_uppercase
|
1602
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1396
1603
|
dtype = dtype if dtype is not None else grad.dtype
|
1397
1604
|
shape = grad.shape
|
1398
1605
|
|
1399
1606
|
if len(shape) == 0: # scalar
|
1400
1607
|
Q = [scale * torch.ones_like(grad, dtype=dtype)]
|
1401
|
-
|
1402
|
-
exprGs = [",->"]
|
1403
|
-
exprP = ",,->"
|
1404
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1405
|
-
|
1406
|
-
# Tensor
|
1407
|
-
if len(shape) > 13:
|
1408
|
-
raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
|
1608
|
+
return Q
|
1409
1609
|
|
1410
1610
|
scale = scale ** (1 / len(shape))
|
1411
1611
|
|
@@ -1418,6 +1618,9 @@ def init_Q_exprs(
|
|
1418
1618
|
sorted_shape = sorted(shape)
|
1419
1619
|
if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
|
1420
1620
|
dim_diag[_max_idx(shape)] = True
|
1621
|
+
elif memory_save_mode == "one_triu":
|
1622
|
+
shape_ranks = np.argsort(np.argsort(shape)) # ranks
|
1623
|
+
dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
|
1421
1624
|
elif memory_save_mode == "all_diag":
|
1422
1625
|
dim_diag = [True for _ in shape]
|
1423
1626
|
else:
|
@@ -1427,66 +1630,90 @@ def init_Q_exprs(
|
|
1427
1630
|
)
|
1428
1631
|
|
1429
1632
|
Q = []
|
1430
|
-
piece1A, piece2A, piece3A = ([], "", "")
|
1431
|
-
exprGs = []
|
1432
|
-
piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
|
1433
1633
|
for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
|
1434
1634
|
if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
|
1435
1635
|
# use diagonal matrix as preconditioner for this dim
|
1436
1636
|
Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
|
1437
|
-
|
1438
|
-
piece1A.append(letters[i])
|
1439
|
-
piece2A = piece2A + letters[i]
|
1440
|
-
piece3A = piece3A + letters[i]
|
1441
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1442
|
-
subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
|
1443
|
-
exprGs.append(subscripts)
|
1444
|
-
piece1P.append(letters[i + 13])
|
1445
|
-
piece2P.append(letters[i + 13])
|
1446
|
-
piece3P = piece3P + letters[i + 13]
|
1447
|
-
piece4P = piece4P + letters[i + 13]
|
1448
1637
|
else:
|
1449
1638
|
# use triangular matrix as preconditioner for this dim
|
1450
1639
|
Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
|
1451
|
-
|
1452
|
-
piece2A = piece2A + letters[i + 13]
|
1453
|
-
piece3A = piece3A + letters[i]
|
1454
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1455
|
-
piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
|
1456
|
-
subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
|
1457
|
-
exprGs.append(subscripts)
|
1458
|
-
a, b, c = (letters[i], letters[i + 13], letters[i + 26])
|
1459
|
-
piece1P.append(a + b)
|
1460
|
-
piece2P.append(a + c)
|
1461
|
-
piece3P = piece3P + c
|
1462
|
-
piece4P = piece4P + b
|
1463
|
-
|
1464
|
-
exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
|
1465
|
-
exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
|
1466
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1640
|
+
return Q
|
1467
1641
|
|
1468
1642
|
|
1469
|
-
@
|
1470
|
-
def psgd_balance_Q(
|
1471
|
-
norms =
|
1472
|
-
geometric_mean = norms
|
1473
|
-
|
1474
|
-
|
1643
|
+
@decorator_knowngood
|
1644
|
+
def psgd_balance_Q(Q):
|
1645
|
+
norms = [promote(q.norm(float("inf"))).log() for q in Q]
|
1646
|
+
geometric_mean = sum([n for n in norms]) / len(Q)
|
1647
|
+
for q, n in zip(Q, norms):
|
1648
|
+
q *= (geometric_mean - n).exp()
|
1475
1649
|
|
1476
1650
|
|
1477
|
-
@
|
1478
|
-
def
|
1479
|
-
u_norm =
|
1480
|
-
v_norm =
|
1481
|
-
scale = (u_norm / v_norm) ** 0.
|
1482
|
-
|
1483
|
-
|
1651
|
+
@decorator_knowngood
|
1652
|
+
def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
|
1653
|
+
u_norm = sum(u.square().sum().double() for u in U)
|
1654
|
+
v_norm = sum(v.square().sum().double() for v in V)
|
1655
|
+
scale = (u_norm / v_norm) ** 0.25 # sqrt of L2 norms; sqrt, as it's 2 factors
|
1656
|
+
scale = torch.where(torch.logical_and(torch.isfinite(scale), scale > 1e-6), scale, 1)
|
1657
|
+
stochastic_multiply_(U, [1 / scale] * len(U))
|
1658
|
+
stochastic_multiply_(V, [scale] * len(V))
|
1659
|
+
return multi_flatten((U, 1), (V, 1), (d, 0))
|
1484
1660
|
|
1485
1661
|
|
1486
1662
|
@decorator
|
1487
1663
|
def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
|
1488
1664
|
dtype = min_dtype([U, V, x])
|
1489
|
-
return x +
|
1665
|
+
return x + compiled_einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
|
1666
|
+
|
1667
|
+
|
1668
|
+
@decorator_knowngood
|
1669
|
+
def _compilable_d_step(
|
1670
|
+
d: Tensor,
|
1671
|
+
d_orig: List[Tensor],
|
1672
|
+
invQtv: Tensor,
|
1673
|
+
vector: Tensor,
|
1674
|
+
inverse_precond_vector: Tensor,
|
1675
|
+
hessian_vector: Tensor,
|
1676
|
+
precond_hessian_vector: Tensor,
|
1677
|
+
eps: Tensor,
|
1678
|
+
step: Tensor,
|
1679
|
+
delayed: bool,
|
1680
|
+
):
|
1681
|
+
precond_hessian_vector = promote(precond_hessian_vector)
|
1682
|
+
hessian_vector = promote(hessian_vector)
|
1683
|
+
vector = promote(vector)
|
1684
|
+
inverse_precond_vector = promote(inverse_precond_vector)
|
1685
|
+
invQtv = promote(invQtv)
|
1686
|
+
inverse_precond_vector = invQtv - inverse_precond_vector
|
1687
|
+
|
1688
|
+
nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector
|
1689
|
+
|
1690
|
+
"""
|
1691
|
+
1) Sketching
|
1692
|
+
1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
|
1693
|
+
1.2) reduced-precision selection of largest element (halves memory traffic)
|
1694
|
+
2) Computation
|
1695
|
+
2.1) select relevant indices
|
1696
|
+
2.2) redo 1.1 in double precision for scalar values
|
1697
|
+
2.3) return high-precision normalized step-size
|
1698
|
+
overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
|
1699
|
+
improving precision
|
1700
|
+
"""
|
1701
|
+
a0 = promote(d) * precond_hessian_vector
|
1702
|
+
a1 = vector
|
1703
|
+
b0 = inverse_precond_vector / promote(d)
|
1704
|
+
b1 = hessian_vector
|
1705
|
+
|
1706
|
+
divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
|
1707
|
+
idx = divisor.bfloat16().flatten().argmax()
|
1708
|
+
a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
|
1709
|
+
b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
|
1710
|
+
divisor = (a * b).sqrt().clamp(min=eps)
|
1711
|
+
step = -step / divisor
|
1712
|
+
|
1713
|
+
# fused update(s)
|
1714
|
+
apply_flat_add(d_orig, nablaD, step)
|
1715
|
+
if not delayed:
|
1716
|
+
copy_stochastic_(d, promote(d) - nablaD * step)
|
1490
1717
|
|
1491
1718
|
|
1492
1719
|
def update_lra_precond_(
|
@@ -1498,13 +1725,14 @@ def update_lra_precond_(
|
|
1498
1725
|
eps: float,
|
1499
1726
|
step: float,
|
1500
1727
|
delayed: bool,
|
1728
|
+
precond_u: bool,
|
1501
1729
|
):
|
1502
1730
|
"""
|
1503
1731
|
Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
|
1504
1732
|
"""
|
1505
1733
|
U_orig, V_orig, d_orig = U, V, d
|
1506
1734
|
|
1507
|
-
U, V, d =
|
1735
|
+
U, V, d = _lra_flatten_and_balance(U, V, d)
|
1508
1736
|
|
1509
1737
|
dtype = min_dtype([U, V, vector, hessian_vector])
|
1510
1738
|
U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
|
@@ -1512,10 +1740,10 @@ def update_lra_precond_(
|
|
1512
1740
|
eps = scalar_guard(eps, vector)
|
1513
1741
|
|
1514
1742
|
Qh = low_rank_mm(U, V, d * hessian_vector)
|
1515
|
-
Ph =
|
1743
|
+
Ph = low_rank_mm(V, U, Qh)
|
1516
1744
|
rank = U.size(1)
|
1517
1745
|
|
1518
|
-
VtU =
|
1746
|
+
VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
|
1519
1747
|
I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
|
1520
1748
|
IpVtU = I + VtU
|
1521
1749
|
invQtv = vector / d
|
@@ -1533,47 +1761,39 @@ def update_lra_precond_(
|
|
1533
1761
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1534
1762
|
|
1535
1763
|
invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
|
1536
|
-
invPv =
|
1537
|
-
invPv = invPv / d
|
1538
|
-
|
1539
|
-
nablaD = Ph * hessian_vector - vector * invPv
|
1540
|
-
divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
|
1541
|
-
divisor = divisor.add(eps).sqrt().max()
|
1542
|
-
d_step = step / divisor
|
1764
|
+
invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
|
1543
1765
|
|
1544
|
-
|
1766
|
+
eps, step = scalar_guard(eps, step, vector)
|
1767
|
+
_compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
|
1545
1768
|
|
1546
1769
|
a, b = Qh, invQtv
|
1547
1770
|
|
1548
|
-
precond_u = random.random() < 0.5 # update either U or V, not both at the same time
|
1549
1771
|
precond = V if precond_u else U
|
1550
|
-
atV =
|
1551
|
-
btV =
|
1552
|
-
atVVt =
|
1553
|
-
btVVt =
|
1554
|
-
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()
|
1772
|
+
atV = compiled_einsum("b,br->r", a, precond) # o == one
|
1773
|
+
btV = compiled_einsum("b,br->r", b, precond)
|
1774
|
+
atVVt = compiled_einsum("r,br->b", atV, precond)
|
1775
|
+
btVVt = compiled_einsum("r,br->b", btV, precond)
|
1776
|
+
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
|
1555
1777
|
if precond_u:
|
1556
|
-
a =
|
1557
|
-
b =
|
1778
|
+
a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
|
1779
|
+
b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
|
1558
1780
|
else:
|
1559
|
-
a = a +
|
1560
|
-
b = b +
|
1561
|
-
a =
|
1562
|
-
b =
|
1781
|
+
a = a + compiled_einsum("br,r->b", V, atV)
|
1782
|
+
b = b + compiled_einsum("br,r->b", V, btV)
|
1783
|
+
a = compiled_einsum("b,r->br", a, atV)
|
1784
|
+
b = compiled_einsum("b,r->br", b, btV)
|
1563
1785
|
apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
|
1564
|
-
|
1565
1786
|
if not delayed:
|
1566
|
-
stochastic_add_([d], [d * nablaD], -d_step)
|
1567
1787
|
stochastic_add_([U if precond_u else V], [b - a], precond_step)
|
1568
1788
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1569
1789
|
|
1570
1790
|
|
1571
|
-
def lra_precond(U, V, d, g):
|
1791
|
+
def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
|
1572
1792
|
"""
|
1573
1793
|
As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
|
1574
1794
|
"""
|
1575
|
-
|
1576
|
-
return d * low_rank_mm(V, U,
|
1795
|
+
new_g = low_rank_mm(U, V, d * g)
|
1796
|
+
return d * low_rank_mm(V, U, new_g)
|
1577
1797
|
|
1578
1798
|
|
1579
1799
|
@decorator_knowngood
|
@@ -1584,16 +1804,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
|
|
1584
1804
|
|
1585
1805
|
|
1586
1806
|
@decorator_knowngood
|
1587
|
-
def
|
1588
|
-
|
1807
|
+
def _compilable_lra_update_(
|
1808
|
+
params: List[Tensor],
|
1809
|
+
update: List[Tensor],
|
1810
|
+
U: Tensor,
|
1811
|
+
V: Tensor,
|
1812
|
+
d: Tensor,
|
1813
|
+
lr: Tensor,
|
1814
|
+
decay: Tensor,
|
1815
|
+
caution: bool,
|
1816
|
+
grads: List[Tensor],
|
1817
|
+
):
|
1818
|
+
update = lra_precond(U, V, d, flatten(update))
|
1589
1819
|
start = 0
|
1590
1820
|
update = update.flatten()
|
1591
|
-
for p in params:
|
1821
|
+
for p, g in zip(params, grads):
|
1592
1822
|
size = p.numel()
|
1593
|
-
|
1823
|
+
update_param_(p, update[start : start + size].view_as(p), lr, decay, caution, g)
|
1594
1824
|
start += size
|
1595
1825
|
|
1596
1826
|
|
1827
|
+
def apply_lra_update(
|
1828
|
+
params: List[Tensor],
|
1829
|
+
update: Tensor,
|
1830
|
+
U: Tensor,
|
1831
|
+
V: Tensor,
|
1832
|
+
d: Tensor,
|
1833
|
+
lr: float,
|
1834
|
+
decay: float,
|
1835
|
+
caution: bool,
|
1836
|
+
grads: List[Tensor],
|
1837
|
+
):
|
1838
|
+
params, grads = list_guard(params, grads)
|
1839
|
+
lr, decay = scalar_guard(lr, decay, params[0])
|
1840
|
+
_compilable_lra_update_(params, update, U, V, d, lr, decay, caution, grads)
|
1841
|
+
|
1842
|
+
|
1597
1843
|
@decorator_knowngood
|
1598
1844
|
def apply_flat_update(params: List[Tensor], update: Tensor):
|
1599
1845
|
start = 0
|
@@ -1604,6 +1850,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
|
|
1604
1850
|
start += size
|
1605
1851
|
|
1606
1852
|
|
1853
|
+
@decorator_knowngood
|
1854
|
+
def zero_(x: List[Tensor]):
|
1855
|
+
for i in x:
|
1856
|
+
i.zero_()
|
1857
|
+
|
1858
|
+
|
1607
1859
|
@decorator_knowngood
|
1608
1860
|
def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
|
1609
1861
|
start = 0
|
@@ -1629,7 +1881,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
|
|
1629
1881
|
@decorator_knowngood
|
1630
1882
|
def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
|
1631
1883
|
last_dim = x[0].shape[-remaining:] if remaining else []
|
1632
|
-
return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
|
1884
|
+
return torch.cat([i.reshape(-1, *last_dim) for i in x if i.numel()], 0)
|
1885
|
+
|
1886
|
+
|
1887
|
+
@decorator_knowngood
|
1888
|
+
def multi_flatten(*xs: Tuple[List[Tensor], int]):
|
1889
|
+
return [flatten(x, i) for x, i in xs]
|
1633
1890
|
|
1634
1891
|
|
1635
1892
|
@decorator_knowngood
|
@@ -1645,149 +1902,564 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
|
|
1645
1902
|
|
1646
1903
|
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
|
1647
1904
|
md = min_dtype(args)
|
1648
|
-
return
|
1905
|
+
return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
|
1649
1906
|
|
1650
1907
|
|
1651
1908
|
@decorator_knowngood
|
1652
1909
|
def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
|
1653
1910
|
triangular_qs = []
|
1911
|
+
conjB = promote(conjB)
|
1654
1912
|
for i, q in enumerate(Qs):
|
1655
1913
|
q = promote(q)
|
1656
1914
|
if q.dim() <= 1:
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
1915
|
+
if conjB.ndim == 0:
|
1916
|
+
conjB = conjB / q
|
1917
|
+
else:
|
1918
|
+
shape = [1] * conjB.ndim
|
1919
|
+
shape[i] = -1
|
1920
|
+
conjB = conjB / q.view(shape)
|
1660
1921
|
else:
|
1661
1922
|
triangular_qs.append((i, q))
|
1662
|
-
return triangular_qs
|
1923
|
+
return triangular_qs, conjB
|
1663
1924
|
|
1664
1925
|
|
1665
1926
|
@decorator_knowngood
|
1666
|
-
def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int,
|
1927
|
+
def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
|
1928
|
+
solved = solved.reshape(transposed_shape)
|
1929
|
+
solved = solved.transpose(-1, last_dim)
|
1667
1930
|
solved = solved.reshape(original_shape)
|
1668
|
-
solved.transpose(
|
1669
|
-
return solved.
|
1931
|
+
solved = solved.transpose(-1, new_dim)
|
1932
|
+
return solved.contiguous(), solved.shape
|
1933
|
+
|
1934
|
+
|
1935
|
+
def ndim_tuple(Q: list[Tensor]) -> tuple:
|
1936
|
+
return tuple(q.ndim for q in Q)
|
1670
1937
|
|
1671
1938
|
|
1672
|
-
def psgd_calc_A_and_conjB(
|
1673
|
-
|
1674
|
-
|
1675
|
-
|
1676
|
-
conjB = conjB.to(promote(G.dtype))
|
1939
|
+
def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "vector") == randn during hvp/whitening
|
1940
|
+
if conjB is None:
|
1941
|
+
conjB = torch.randn_like(G)
|
1942
|
+
exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
|
1677
1943
|
A = casted_einsum(exprA, *Q, G)
|
1678
1944
|
solve = torch.compiler.disable(torch.linalg.solve_triangular)
|
1679
|
-
original_shape = conjB.shape
|
1945
|
+
transposed_shape = original_shape = conjB.shape
|
1680
1946
|
prev_i = -1
|
1681
|
-
|
1682
|
-
|
1947
|
+
qs, conjB = _psgd_calc_scalars_(Q, conjB)
|
1948
|
+
for i, tri_q in qs:
|
1949
|
+
conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
|
1683
1950
|
prev_i = i
|
1684
1951
|
conjB = solve(tri_q, conjB, upper=True, left=False)
|
1685
|
-
conjB = _reshape_conjB(conjB, original_shape, prev_i,
|
1952
|
+
conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
|
1686
1953
|
return A, conjB
|
1687
1954
|
|
1688
1955
|
|
1689
1956
|
@decorator_knowngood
|
1690
|
-
def
|
1691
|
-
|
1692
|
-
|
1957
|
+
def _random_projection(x: Tensor, scale: Optional[Tensor]):
|
1958
|
+
if scale is None:
|
1959
|
+
scale = x.norm(float("inf")).clamp(min=1e-8)
|
1960
|
+
k = 2 ** math.ceil(math.log2(math.log2(min(x.shape)))) # next-largest-power-of-2 of log2-of-size
|
1961
|
+
norm = x.square().sum(0)
|
1962
|
+
indices = torch.topk(norm, k, largest=True).indices
|
1963
|
+
return x.index_select(1, indices).contiguous() / scale, scale
|
1693
1964
|
|
1694
1965
|
|
1695
|
-
def
|
1696
|
-
|
1697
|
-
|
1698
|
-
|
1699
|
-
|
1700
|
-
|
1701
|
-
|
1702
|
-
|
1703
|
-
|
1966
|
+
def max_singular_value_exact(A, use_lobpcg: bool = False):
|
1967
|
+
try:
|
1968
|
+
if use_lobpcg:
|
1969
|
+
A = A @ A.T
|
1970
|
+
eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
|
1971
|
+
return eigval[0].sqrt()
|
1972
|
+
else:
|
1973
|
+
return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype) # == linalg.matrix_norm(A, ord=2)
|
1974
|
+
except (torch.linalg.LinAlgError, RuntimeError):
|
1975
|
+
return max_singular_value_power_iter(promote(A), iterations=2)
|
1704
1976
|
|
1705
1977
|
|
1706
1978
|
@decorator_knowngood
|
1707
|
-
def
|
1708
|
-
|
1979
|
+
def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
|
1980
|
+
"""
|
1981
|
+
Rayleigh quotient of row with the largest norm + optional power iterations
|
1982
|
+
"""
|
1983
|
+
x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
|
1984
|
+
x_norm = promote(x_norm)
|
1985
|
+
|
1986
|
+
def _inner():
|
1987
|
+
A = A_outer
|
1988
|
+
x = A.index_select(0, max_idx).flatten().contiguous()
|
1989
|
+
A = stochastic_round_(A / x_norm)
|
1990
|
+
x = x / x_norm
|
1991
|
+
|
1992
|
+
def _mv(x):
|
1993
|
+
return promote(A.T.mv(A.mv(stochastic_round_(x))))
|
1994
|
+
|
1995
|
+
for _ in range(iterations):
|
1996
|
+
# A @ A.T @ x, but explicitly telling torch.compile not to compute the full matrix
|
1997
|
+
x = F.normalize(_mv(x), dim=0)
|
1998
|
+
out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
|
1999
|
+
return out.squeeze().clone()
|
2000
|
+
|
2001
|
+
return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
|
1709
2002
|
|
1710
2003
|
|
1711
2004
|
@decorator_knowngood
|
1712
|
-
def
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
2005
|
+
def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
|
2006
|
+
"""
|
2007
|
+
Adapted from @evanatyourservice
|
2008
|
+
"""
|
2009
|
+
Y, max_abs = _random_projection(A, max_abs)
|
2010
|
+
Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
|
2011
|
+
Q = Q / max_abs
|
2012
|
+
Z = A.T @ Q
|
2013
|
+
W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
|
2014
|
+
sketch_norm = max_singular_value_exact(Z.T @ W)
|
2015
|
+
return sketch_norm * max_abs
|
2016
|
+
|
2017
|
+
|
2018
|
+
def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
|
2019
|
+
if A.ndim <= 2:
|
2020
|
+
return max_singular_value(A, max_svd, use_cholesky, power_iter)
|
2021
|
+
|
2022
|
+
base = einsum_base[: A.ndim]
|
2023
|
+
A16 = stochastic_round_(A)
|
2024
|
+
squares = [compiled_einsum(f"{base},{base.replace(b, b.upper())}->{b}{b.upper()}", A16, A16) for b in base]
|
2025
|
+
svds = [max_singular_value(promote(s), max_svd, use_cholesky, power_iter) for s in squares]
|
2026
|
+
svds = torch.stack(svds)
|
2027
|
+
return svds.max().sqrt().to(A.dtype) # sqrt because we took the SVD of a squared matrix
|
1716
2028
|
|
1717
2029
|
|
1718
2030
|
@decorator_knowngood
|
1719
|
-
def
|
1720
|
-
|
2031
|
+
def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
|
2032
|
+
if A.ndim < 2:
|
2033
|
+
return A.abs().max()
|
2034
|
+
if A.ndim > 2:
|
2035
|
+
raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
|
2036
|
+
if min(A.shape) <= max_svd:
|
2037
|
+
return max_singular_value_exact(A) # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
|
2038
|
+
if use_cholesky or power_iter < 0:
|
2039
|
+
return max_singular_value_cholesky(A)
|
2040
|
+
return max_singular_value_power_iter(A, None, iterations=power_iter)
|
1721
2041
|
|
1722
2042
|
|
1723
2043
|
@decorator_knowngood
|
1724
|
-
def
|
1725
|
-
|
1726
|
-
|
1727
|
-
|
2044
|
+
def clamped_max_singular_value(
|
2045
|
+
A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
|
2046
|
+
) -> Tensor:
|
2047
|
+
norm = A.norm() # L2 norm is an upper bound for the spectral norm. If the upper bound is below the minimum, the real value will be too.
|
2048
|
+
out = cond(norm > min, lambda: max_singular_value(A, max_svd, use_cholesky, power_iter), lambda: norm.clone())
|
2049
|
+
return out.clamp(min=min)
|
2050
|
+
|
2051
|
+
|
2052
|
+
@decorator_knowngood
|
2053
|
+
def min_singular_value(
|
2054
|
+
A: Tensor,
|
2055
|
+
power_iter: int = 5,
|
2056
|
+
safety: float = 1.05,
|
2057
|
+
max_svd: int = 32,
|
2058
|
+
):
|
2059
|
+
if A.ndim < 2:
|
2060
|
+
return A.abs().min()
|
2061
|
+
|
2062
|
+
n = A.size(0)
|
2063
|
+
if n <= max_svd:
|
2064
|
+
try:
|
2065
|
+
eigs = torch.linalg.eigvalsh(promote(A))
|
2066
|
+
return eigs.min().to(A.dtype)
|
2067
|
+
except torch.linalg.LinAlgError:
|
2068
|
+
pass
|
2069
|
+
|
2070
|
+
lambda_max_hat = max_singular_value(A, power_iter=power_iter)
|
2071
|
+
lambda_upper = lambda_max_hat * safety
|
2072
|
+
|
2073
|
+
row_norms = A.norm(dim=1)
|
2074
|
+
norm, idx = row_norms.min(dim=0)
|
2075
|
+
v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))
|
2076
|
+
|
2077
|
+
v = v / promote(v.norm())
|
2078
|
+
for _ in range(power_iter):
|
2079
|
+
v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
|
2080
|
+
v = v / promote(v.norm())
|
2081
|
+
mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))
|
2082
|
+
|
2083
|
+
lambda_min_hat = lambda_upper - mu_hat
|
2084
|
+
|
2085
|
+
def _approx():
|
2086
|
+
mu = A.trace() / n
|
2087
|
+
sigma_square = A.square().sum() / n - mu**2
|
2088
|
+
return mu - (sigma_square / (n - 1)).sqrt()
|
2089
|
+
|
2090
|
+
return cond(
|
2091
|
+
(~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
|
2092
|
+
).squeeze()
|
2093
|
+
|
2094
|
+
|
2095
|
+
@decorator_knowngood
|
2096
|
+
def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
|
2097
|
+
if isinstance(Q[0], tuple):
|
2098
|
+
psgd_balance_Q([o[1] for o in Q])
|
2099
|
+
return line_to_triu(Q, symmetric_output)
|
2100
|
+
psgd_balance_Q(Q)
|
2101
|
+
return Q
|
2102
|
+
|
2103
|
+
|
2104
|
+
@functools.lru_cache(maxsize=None)
|
2105
|
+
def calcG_expr(q_dim, g_dim):
|
2106
|
+
exprs = []
|
2107
|
+
base = einsum_base[:g_dim]
|
2108
|
+
for i, q in enumerate(q_dim):
|
2109
|
+
new = list(base)
|
2110
|
+
if q == 2:
|
2111
|
+
new[i] = "Z"
|
2112
|
+
out = f"{base[i]}Z"
|
2113
|
+
else:
|
2114
|
+
out = base[i]
|
2115
|
+
exprs.append(f"{base},{''.join(new)}->{out}")
|
2116
|
+
return exprs
|
2117
|
+
|
2118
|
+
|
2119
|
+
def eye_like(x: Tensor):
|
2120
|
+
if x.ndim < 2:
|
2121
|
+
return torch.ones_like(x)
|
2122
|
+
assert x.ndim == 2
|
2123
|
+
assert x.size(0) == x.size(1)
|
2124
|
+
return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
|
2125
|
+
|
2126
|
+
|
2127
|
+
@decorator_knowngood
|
2128
|
+
def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
|
2129
|
+
"""
|
2130
|
+
Idea:
|
2131
|
+
G should be zeroth power. So, all Qs together should approximate the G's inverse.
|
2132
|
+
Assuming G is 2-dimensional, we'd have two preconditioning Q's: L, R
|
2133
|
+
Optimize LGR being a zeroth power using `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T + (LGR) , I )`,
|
2134
|
+
then backprop to L/R jointly.
|
2135
|
+
This function computes the gradients for L/R, with an outer optimizer layer handling the rest.
|
2136
|
+
|
2137
|
+
`psgd_precond_grad` computes LGR for the general (n-dimensional) case
|
2138
|
+
`exprG` contains the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general n-dim case
|
2139
|
+
Args:
|
2140
|
+
G: Gradient that should be orthogonalized
|
2141
|
+
Q: List of preconditioner tensors.
|
2142
|
+
|
2143
|
+
Returns:
|
2144
|
+
- List of gradients with respect to Q (d_Q).
|
2145
|
+
"""
|
2146
|
+
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
|
2147
|
+
|
2148
|
+
G16 = stochastic_round_(G)
|
2149
|
+
Q16 = [stochastic_round_(q) for q in Q]
|
2150
|
+
P = psgd_precond_grad(G16, Q16) # Q₀GQ₁
|
2151
|
+
|
2152
|
+
d_P = torch.zeros_like(G)
|
2153
|
+
base = einsum_base[: G.ndim]
|
2154
|
+
for i, exprG in enumerate(exprGs):
|
2155
|
+
pp = compiled_einsum(exprG, P, P)
|
2156
|
+
error = pp - eye_like(pp)
|
2157
|
+
dim = einsum_base[i]
|
2158
|
+
if pp.ndim == 2:
|
2159
|
+
new = dim.upper()
|
2160
|
+
prec = f"{new}{dim}"
|
2161
|
+
else:
|
2162
|
+
new = dim
|
2163
|
+
prec = dim
|
2164
|
+
d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)
|
2165
|
+
|
2166
|
+
d_P = stochastic_round_(d_P) # accumulate in fp32 and round at the end
|
2167
|
+
grads = []
|
2168
|
+
for i, exprG in enumerate(exprGs):
|
2169
|
+
new_q = Q16[:]
|
2170
|
+
new_q[i] = eye_like(new_q[i])
|
2171
|
+
pq = psgd_precond_grad(G16, new_q)
|
2172
|
+
grad = compiled_einsum(exprG, pq, d_P)
|
2173
|
+
if grad.ndim == 2:
|
2174
|
+
grad = (grad + grad.T) / 2
|
2175
|
+
grads.append(grad)
|
2176
|
+
|
2177
|
+
return grads, P.to(G.dtype)
|
2178
|
+
|
2179
|
+
|
2180
|
+
def _inverse_initial_guess(gg):
|
2181
|
+
n = gg.shape[0]
|
2182
|
+
|
2183
|
+
sigma_max = promote(gg.norm())
|
2184
|
+
|
2185
|
+
trace_gg = promote(torch.trace(gg))
|
2186
|
+
sigma_min_approx = trace_gg / (n * sigma_max)
|
2187
|
+
|
2188
|
+
return sigma_max, sigma_min_approx
|
2189
|
+
|
2190
|
+
|
2191
|
+
@decorator_knowngood
|
2192
|
+
def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
|
2193
|
+
k = torch.arange(degree, dtype=torch.float64, device=device)
|
2194
|
+
rotation = (2 * k + 1) * math.pi / (2 * degree)
|
2195
|
+
f = (rotation.cos() + 1 + eps) ** -0.5
|
2196
|
+
rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()
|
2197
|
+
coeff0 = f.sum() / degree
|
2198
|
+
coeffs = f @ rotation * 2 / degree
|
2199
|
+
return coeff0.float(), coeffs.float()
|
2200
|
+
|
2201
|
+
|
2202
|
+
@decorator_knowngood
|
2203
|
+
def _psgd_default_preconditioner_grad(
|
2204
|
+
terms: List[Tuple[Tensor, Tensor]],
|
2205
|
+
Q: List[Tensor],
|
2206
|
+
) -> List[Tensor]:
|
2207
|
+
out = []
|
2208
|
+
for q, (x, y) in zip(Q, terms):
|
2209
|
+
x = promote(x)
|
2210
|
+
y = promote(y)
|
2211
|
+
update = x - y
|
2212
|
+
if q.ndim < 2:
|
2213
|
+
update = q * update
|
2214
|
+
else:
|
2215
|
+
update = (q @ update).triu()
|
2216
|
+
out.append(update)
|
2217
|
+
return out
|
1728
2218
|
|
1729
2219
|
|
1730
2220
|
@decorator
|
1731
|
-
def psgd_update_precond(
|
2221
|
+
def psgd_update_precond(
|
2222
|
+
G: Tensor,
|
2223
|
+
precond_lr: float,
|
2224
|
+
oq: "TriuOrLine",
|
2225
|
+
store_triu_as_line: bool,
|
2226
|
+
velocity: Optional[List[Tensor]],
|
2227
|
+
beta2: float,
|
2228
|
+
ortho_method: Optional[str],
|
2229
|
+
V: Tensor,
|
2230
|
+
running_lower_bound: List[Tensor],
|
2231
|
+
lower_bount_beta: float,
|
2232
|
+
power_iter: int,
|
2233
|
+
) -> None:
|
1732
2234
|
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
1733
|
-
|
1734
|
-
|
1735
|
-
precond_lr = scalar_guard(precond_lr, G)
|
1736
|
-
|
1737
|
-
|
1738
|
-
|
1739
|
-
|
1740
|
-
|
1741
|
-
|
1742
|
-
|
1743
|
-
|
2235
|
+
Q = _balance_to_triu(oq)
|
2236
|
+
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
|
2237
|
+
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
|
2238
|
+
|
2239
|
+
A, conjB = psgd_calc_A_and_conjB(G, Q, V)
|
2240
|
+
terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
|
2241
|
+
del A, conjB, V
|
2242
|
+
updates = _psgd_default_preconditioner_grad(terms, Q)
|
2243
|
+
_psgd_precond_update_(
|
2244
|
+
updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
|
2245
|
+
)
|
2246
|
+
return None
|
2247
|
+
|
2248
|
+
|
2249
|
+
@decorator_knowngood
|
2250
|
+
def bf16_matmul(x: Tensor, y: Tensor):
|
2251
|
+
return (promote(x) @ promote(y)).to(x.dtype)
|
2252
|
+
|
2253
|
+
|
2254
|
+
def if_iscompiling(fn):
|
2255
|
+
base = getattr(torch, fn.__name__, None)
|
2256
|
+
|
2257
|
+
def _fn(x):
|
2258
|
+
if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
|
2259
|
+
return base(x)
|
2260
|
+
return fn(x)
|
2261
|
+
|
2262
|
+
return _fn
|
2263
|
+
|
2264
|
+
|
2265
|
+
@if_iscompiling
|
2266
|
+
def while_loop(cond, body, state):
|
2267
|
+
"""
|
2268
|
+
dispatches to torch.while_loop if we're compiling. otherwise, falls back to a naive + slow baseline
|
2269
|
+
useful for debugging
|
2270
|
+
"""
|
2271
|
+
while cond(*state).item():
|
2272
|
+
state = body(*state)
|
2273
|
+
return state
|
2274
|
+
|
2275
|
+
|
2276
|
+
@if_iscompiling
|
2277
|
+
def cond(cond, true_fn, false_fn):
|
2278
|
+
"""
|
2279
|
+
dispatches to torch.cond if we're compiling. otherwise, falls back to a naive + slow baseline
|
2280
|
+
useful for debugging
|
2281
|
+
"""
|
2282
|
+
|
2283
|
+
if cond.item():
|
2284
|
+
return true_fn()
|
2285
|
+
return false_fn()
|
2286
|
+
|
2287
|
+
|
2288
|
+
def cond_n(cond_val: Tensor, *fns):
|
2289
|
+
fns = list(fns)
|
2290
|
+
fn = fns.pop(0)
|
2291
|
+
if not fns:
|
2292
|
+
return fn
|
2293
|
+
return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
|
2294
|
+
|
2295
|
+
|
2296
|
+
@decorator_knowngood
|
2297
|
+
def _psgd_precond_update_(
|
2298
|
+
matmuled: List[Optional[Tensor]],
|
2299
|
+
Q: "TriuOrLine",
|
2300
|
+
running_lower_bound: List[Tensor],
|
2301
|
+
lower_bount_beta: Tensor,
|
2302
|
+
precond_lr: Tensor,
|
2303
|
+
store_triu_as_line: bool,
|
2304
|
+
power_iter: int,
|
2305
|
+
):
|
2306
|
+
for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
|
2307
|
+
if isinstance(oq, tuple):
|
2308
|
+
oq = oq[1]
|
2309
|
+
|
2310
|
+
q = promote(oq)
|
2311
|
+
if update.ndim < 2:
|
2312
|
+
lb = update.norm(float("inf"))
|
1744
2313
|
else:
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
1748
|
-
|
1749
|
-
|
2314
|
+
lb = max_singular_value(update, power_iter=power_iter)
|
2315
|
+
update = promote(update)
|
2316
|
+
if store_triu_as_line:
|
2317
|
+
update = triu_to_line([update])[0][1]
|
2318
|
+
|
2319
|
+
lb = promote(lb)
|
2320
|
+
lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
|
2321
|
+
copy_stochastic_(lb_state, lb)
|
2322
|
+
copy_stochastic_(oq, q - update / lb * precond_lr)
|
2323
|
+
|
2324
|
+
|
2325
|
+
@decorator_knowngood
|
2326
|
+
def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
|
2327
|
+
"""
|
2328
|
+
I: Identity
|
2329
|
+
U: Update / gg / target
|
2330
|
+
Q: q, preconditioner
|
2331
|
+
scale: scalar scale
|
2332
|
+
---
|
2333
|
+
U = T * scale - I
|
2334
|
+
F = I - U # = 2I - U * scale
|
2335
|
+
O = F @ Q @ F - Q
|
2336
|
+
"""
|
2337
|
+
out = []
|
2338
|
+
for gg, q in zip(GG, Q):
|
2339
|
+
if gg.ndim < 2:
|
2340
|
+
scale = max(1, gg.numel()) / numel
|
2341
|
+
target = promote(gg)
|
2342
|
+
update = target * scale - 1
|
2343
|
+
out.append(q - (1 - update) * q * (1 - update))
|
1750
2344
|
else:
|
1751
|
-
|
2345
|
+
scale = gg.size(0) / numel
|
2346
|
+
gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
|
2347
|
+
update = q - gg @ q @ gg
|
2348
|
+
out.append(update + update.T) # make matrix symmetric
|
2349
|
+
return out
|
2350
|
+
|
2351
|
+
|
2352
|
+
@decorator
|
2353
|
+
def inverse_free_psgd_update_precond(
|
2354
|
+
G: Tensor,
|
2355
|
+
precond_lr: float,
|
2356
|
+
oq: List[Tensor],
|
2357
|
+
store_triu_as_line: bool,
|
2358
|
+
velocity: Optional[List[Tensor]],
|
2359
|
+
beta2: float,
|
2360
|
+
ortho_method: Optional[str],
|
2361
|
+
V: None,
|
2362
|
+
running_lower_bound: List[Tensor],
|
2363
|
+
lower_bount_beta: float,
|
2364
|
+
power_iter: int,
|
2365
|
+
) -> Tensor:
|
2366
|
+
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
2367
|
+
assert V is None
|
2368
|
+
assert ortho_method is None
|
2369
|
+
assert velocity is None
|
2370
|
+
del V, ortho_method, velocity
|
2371
|
+
|
2372
|
+
Q = _balance_to_triu(oq, True)
|
2373
|
+
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
|
2374
|
+
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
|
2375
|
+
|
2376
|
+
G = psgd_precond_grad(G, Q)
|
2377
|
+
terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
|
2378
|
+
matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
|
2379
|
+
_psgd_precond_update_(
|
2380
|
+
matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
|
2381
|
+
)
|
2382
|
+
return G
|
1752
2383
|
|
1753
2384
|
|
1754
2385
|
@decorator_knowngood
|
1755
|
-
def
|
1756
|
-
|
1757
|
-
x =
|
1758
|
-
norm =
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
2386
|
+
def _clip(x, norm, clip_at, eps=1e-8):
|
2387
|
+
x32 = promote(x)
|
2388
|
+
# (x / y.clamp(min=eps)).clamp(max=1) == x / y.clamp(min=max(x, eps))
|
2389
|
+
norm = clip_at / norm.clamp(min=max(clip_at, eps))
|
2390
|
+
x32 = x32 * norm
|
2391
|
+
copy_stochastic_(x, x32)
|
2392
|
+
|
2393
|
+
|
2394
|
+
@decorator_knowngood
|
2395
|
+
def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
|
2396
|
+
for x in xs:
|
2397
|
+
_clip(x, promote(x).norm(), clip_at, eps)
|
1762
2398
|
|
1763
2399
|
|
1764
2400
|
def l2_normalization_(x, clip_at: float = 1e-8):
|
1765
2401
|
x = list_guard(x)
|
1766
|
-
|
2402
|
+
_compilable_l2_clip_(x, clip_at)
|
2403
|
+
return x
|
1767
2404
|
|
1768
2405
|
|
1769
2406
|
def l2_clip_(x, clip_at: float = 1.0):
|
1770
2407
|
x = list_guard(x)
|
1771
|
-
|
2408
|
+
_compilable_l2_clip_(x, clip_at)
|
2409
|
+
return x
|
1772
2410
|
|
1773
2411
|
|
1774
2412
|
@decorator_knowngood
|
1775
|
-
def _compilable_rmsnorm_clip_(
|
1776
|
-
x
|
1777
|
-
|
1778
|
-
norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
|
1779
|
-
torch._foreach_maximum_(norm, clip_at)
|
1780
|
-
return torch._foreach_div(x, norm)
|
2413
|
+
def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
|
2414
|
+
for x in xs:
|
2415
|
+
_clip(x, promote(x).square().mean().sqrt(), clip_at, eps)
|
1781
2416
|
|
1782
2417
|
|
1783
2418
|
def rmsnorm_clip_(x, clip_at: float = 1.0):
|
1784
2419
|
x = list_guard(x)
|
1785
|
-
|
2420
|
+
_compilable_rmsnorm_clip_(x, clip_at)
|
2421
|
+
return x
|
2422
|
+
|
2423
|
+
|
2424
|
+
@decorator_knowngood
|
2425
|
+
def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
|
2426
|
+
norm = 0
|
2427
|
+
numel = sum([i.numel() for i in x])
|
2428
|
+
for i in x:
|
2429
|
+
norm += promote(i).square().sum()
|
2430
|
+
norm = (norm / numel) ** 0.5
|
2431
|
+
scalar = clip_at / norm.clamp(min=max(clip_at, eps))
|
2432
|
+
stochastic_multiply_(x, scalar)
|
2433
|
+
|
2434
|
+
|
2435
|
+
def global_rmsnorm_clip(x, clip_at: float = 1.0):
|
2436
|
+
x = list_guard(x)
|
2437
|
+
clip_at = scalar_guard(clip_at, x[0])
|
2438
|
+
_compilable_global_rmsnorm_clip_(x, clip_at)
|
2439
|
+
return x
|
2440
|
+
|
2441
|
+
|
2442
|
+
@decorator_knowngood
|
2443
|
+
def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
|
2444
|
+
norm = 0
|
2445
|
+
for i in x:
|
2446
|
+
norm += promote(i).square().sum()
|
2447
|
+
norm = norm**0.5
|
2448
|
+
scalar = clip_at / norm.clamp(min=max(clip_at, eps))
|
2449
|
+
stochastic_multiply_(x, scalar)
|
2450
|
+
|
2451
|
+
|
2452
|
+
def global_l2norm_clip(x, clip_at: float = 1.0):
|
2453
|
+
x = list_guard(x)
|
2454
|
+
clip_at = scalar_guard(clip_at, x[0])
|
2455
|
+
_compilable_global_l2norm_clip_(x, clip_at)
|
2456
|
+
return x
|
1786
2457
|
|
1787
2458
|
|
1788
2459
|
def rmsnorm_normalize_(x, clip_at: float = 1e-6):
|
1789
2460
|
x = list_guard(x)
|
1790
|
-
|
2461
|
+
_compilable_rmsnorm_clip_(x, clip_at)
|
2462
|
+
return x
|
1791
2463
|
|
1792
2464
|
|
1793
2465
|
@decorator_knowngood
|
@@ -1920,35 +2592,25 @@ def triu_to_line(Q_list: List[Tensor]):
|
|
1920
2592
|
if q.dim() < 2:
|
1921
2593
|
out.append((None, q))
|
1922
2594
|
else:
|
1923
|
-
out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
|
2595
|
+
out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
|
1924
2596
|
return out
|
1925
2597
|
|
1926
2598
|
|
1927
|
-
|
1928
|
-
|
1929
|
-
assert n * (n + 1) == 2 * numel
|
1930
|
-
return n, n
|
1931
|
-
|
1932
|
-
|
1933
|
-
@decorator
|
1934
|
-
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
|
2599
|
+
@decorator_knowngood
|
2600
|
+
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
|
1935
2601
|
new = []
|
1936
2602
|
for shape, q in Q_list:
|
1937
2603
|
if shape is not None:
|
1938
|
-
|
1939
|
-
|
1940
|
-
x
|
1941
|
-
|
2604
|
+
x, y = torch.triu_indices(*shape, device=q.device)
|
2605
|
+
q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
|
2606
|
+
q_mat[x, y] = q
|
2607
|
+
if symmetric_output:
|
2608
|
+
q_mat[y, x] = q
|
2609
|
+
q = q_mat
|
1942
2610
|
new.append(q)
|
1943
2611
|
return new
|
1944
2612
|
|
1945
2613
|
|
1946
|
-
def update_triu_(q_state, materialised):
|
1947
|
-
for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
|
1948
|
-
assert shape0 == shape1
|
1949
|
-
copy_stochastic_(q, m)
|
1950
|
-
|
1951
|
-
|
1952
2614
|
_warned = set()
|
1953
2615
|
|
1954
2616
|
|
@@ -1971,52 +2633,118 @@ def psgd_should_update(
|
|
1971
2633
|
return int(group[name]) > int(cumulative_prob)
|
1972
2634
|
|
1973
2635
|
|
2636
|
+
@functools.lru_cache(maxsize=None)
|
2637
|
+
def cached_precond_grad_expr(Q_dim, grad_dim):
|
2638
|
+
expr = [f"{c.upper()}{c}" if q_ == 2 else c for c, q_ in zip(einsum_base, Q_dim)]
|
2639
|
+
expr = ",".join(expr)
|
2640
|
+
grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
|
2641
|
+
out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
|
2642
|
+
return f"{expr},{grad_expr}->{out_expr}"
|
2643
|
+
|
2644
|
+
|
1974
2645
|
@decorator_knowngood
|
1975
2646
|
def precond_grad_cached_(
|
1976
|
-
|
2647
|
+
ea: Tensor,
|
2648
|
+
cached_q: List[Tensor],
|
2649
|
+
caution: bool = False,
|
2650
|
+
grad: Optional[Tensor] = None,
|
2651
|
+
cast: bool = True,
|
1977
2652
|
):
|
1978
2653
|
if caution:
|
1979
2654
|
ea = _compilable_cautioning(grad, ea)
|
1980
2655
|
md = min_dtype(list(cached_q) + [ea])
|
1981
2656
|
args = [q.to(md) for q in cached_q]
|
1982
2657
|
args = args + [ea.to(md)]
|
1983
|
-
|
2658
|
+
expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
|
2659
|
+
new = compiled_einsum(expr, *args)
|
1984
2660
|
if cast:
|
1985
2661
|
return new.to(ea.dtype)
|
1986
2662
|
return new
|
1987
2663
|
|
1988
2664
|
|
2665
|
+
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
|
2666
|
+
|
2667
|
+
|
1989
2668
|
@decorator_knowngood
|
1990
|
-
def _compilable_fused_precond_grad_cached_(
|
1991
|
-
precond = precond_grad_cached_(
|
2669
|
+
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
|
2670
|
+
precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
|
1992
2671
|
update_param_(param, precond, lr, decay, caution=False)
|
1993
2672
|
|
1994
2673
|
|
1995
|
-
def fused_precond_grad_cached_(
|
1996
|
-
lr = scalar_guard(lr, param[0])
|
1997
|
-
_compilable_fused_precond_grad_cached_(
|
2674
|
+
def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
|
2675
|
+
lr, decay = scalar_guard(lr, decay, param[0])
|
2676
|
+
_compilable_fused_precond_grad_cached_(ea, param, lr, grad, decay, caution, cached_q)
|
2677
|
+
|
2678
|
+
|
2679
|
+
@functools.lru_cache(maxsize=None)
|
2680
|
+
def precond_grad_expr(Q_dim, grad_dim):
|
2681
|
+
expr = [
|
2682
|
+
f"{c2}{c.upper()},{c2}{c}" if q_ == 2 else f"{c},{c}" for c, c2, q_ in zip(einsum_base, einsum_base[13:], Q_dim)
|
2683
|
+
]
|
2684
|
+
expr = ",".join(expr)
|
2685
|
+
grad_expr = "".join(c for c, _ in zip(einsum_base, range(grad_dim)))
|
2686
|
+
out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr)
|
2687
|
+
return f"{expr},{grad_expr}->{out_expr}"
|
1998
2688
|
|
1999
2689
|
|
2000
2690
|
@decorator_knowngood
|
2001
|
-
def psgd_precond_grad(
|
2691
|
+
def psgd_precond_grad(
|
2692
|
+
ea: Tensor,
|
2693
|
+
preconds: TriuOrLine,
|
2694
|
+
caution: bool = False,
|
2695
|
+
grad: Optional[Tensor] = None,
|
2696
|
+
store_triu_as_line: bool = False,
|
2697
|
+
symmetric_output: bool = False,
|
2698
|
+
):
|
2002
2699
|
if caution:
|
2003
2700
|
ea = _compilable_cautioning(grad, ea)
|
2701
|
+
if store_triu_as_line:
|
2702
|
+
preconds = line_to_triu(preconds, symmetric_output)
|
2004
2703
|
md = min_dtype(list(preconds) + [ea])
|
2005
2704
|
args = [q.to(md) for q in preconds]
|
2006
|
-
|
2007
|
-
new =
|
2705
|
+
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
|
2706
|
+
new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
|
2008
2707
|
return new.to(ea.dtype)
|
2009
2708
|
|
2010
2709
|
|
2011
2710
|
@decorator_knowngood
|
2012
|
-
def _compilable_fused_psgd_precond_grad(
|
2013
|
-
|
2711
|
+
def _compilable_fused_psgd_precond_grad(
|
2712
|
+
ea: Tensor,
|
2713
|
+
param,
|
2714
|
+
lr,
|
2715
|
+
grad,
|
2716
|
+
decay,
|
2717
|
+
caution,
|
2718
|
+
preconds: TriuOrLine,
|
2719
|
+
store_triu_as_line: bool = False,
|
2720
|
+
symmetric_output: bool = False,
|
2721
|
+
):
|
2722
|
+
precond = psgd_precond_grad(
|
2723
|
+
ea,
|
2724
|
+
preconds,
|
2725
|
+
caution=caution,
|
2726
|
+
grad=grad,
|
2727
|
+
store_triu_as_line=store_triu_as_line,
|
2728
|
+
symmetric_output=symmetric_output,
|
2729
|
+
)
|
2014
2730
|
update_param_(param, precond, lr, decay, caution=False, grad=grad)
|
2015
2731
|
|
2016
2732
|
|
2017
|
-
def fused_psgd_precond_grad(
|
2018
|
-
|
2019
|
-
|
2733
|
+
def fused_psgd_precond_grad(
|
2734
|
+
ea: Tensor,
|
2735
|
+
param,
|
2736
|
+
lr,
|
2737
|
+
grad,
|
2738
|
+
decay,
|
2739
|
+
caution,
|
2740
|
+
preconds: TriuOrLine,
|
2741
|
+
store_triu_as_line: bool = False,
|
2742
|
+
symmetric_output: bool = False,
|
2743
|
+
):
|
2744
|
+
lr, decay = scalar_guard(lr, decay, param[0])
|
2745
|
+
_compilable_fused_psgd_precond_grad(
|
2746
|
+
ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
|
2747
|
+
)
|
2020
2748
|
|
2021
2749
|
|
2022
2750
|
@decorator_knowngood
|
@@ -2068,7 +2796,15 @@ def caution(g, update):
|
|
2068
2796
|
return _compilable_cautioning(g, update)
|
2069
2797
|
|
2070
2798
|
|
2071
|
-
def
|
2799
|
+
def _inner_precond_update_prob_schedule(
|
2800
|
+
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2801
|
+
):
|
2802
|
+
return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
|
2803
|
+
|
2804
|
+
|
2805
|
+
def precond_update_prob_schedule(
|
2806
|
+
max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2807
|
+
):
|
2072
2808
|
"""Anneal preconditioner update probability during beginning of training.
|
2073
2809
|
|
2074
2810
|
PSGD benefits from more preconditioner updates at the beginning of training,
|
@@ -2079,11 +2815,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
|
|
2079
2815
|
`min_prob` by ~4000 steps. Default settings work very well for most models and
|
2080
2816
|
training regimes.
|
2081
2817
|
"""
|
2082
|
-
|
2083
|
-
|
2084
|
-
|
2085
|
-
|
2086
|
-
return _schedule
|
2818
|
+
return functools.partial(
|
2819
|
+
_inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
|
2820
|
+
)
|
2087
2821
|
|
2088
2822
|
|
2089
2823
|
def merge_group(group, *tensors):
|
@@ -2217,3 +2951,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
|
2217
2951
|
def disable_caution_scaling():
|
2218
2952
|
global _compilable_cautioning
|
2219
2953
|
_compilable_cautioning = _compilable_caution_no_scale
|
2954
|
+
|
2955
|
+
|
2956
|
+
@decorator_knowngood
|
2957
|
+
def sam_step(parameters, ball_size, adaptive: bool = True):
|
2958
|
+
old_params = []
|
2959
|
+
for p in parameters:
|
2960
|
+
old_params.append(p.detach().clone())
|
2961
|
+
grad = promote(p.grad)
|
2962
|
+
if adaptive:
|
2963
|
+
grad = grad * promote(p).square()
|
2964
|
+
stochastic_add_(p.data, grad, ball_size)
|
2965
|
+
p.grad.zero_()
|
2966
|
+
return old_params
|