heavyball 1.7.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +276 -37
- heavyball/chainable.py +419 -206
- heavyball/helpers.py +808 -0
- heavyball/utils.py +1105 -305
- heavyball-2.0.0.dist-info/METADATA +122 -0
- heavyball-2.0.0.dist-info/RECORD +9 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dist-info}/WHEEL +1 -1
- heavyball/optimizations/__init__.py +0 -38
- heavyball/optimizations/integrator.py +0 -169
- heavyball/optimizations/optimizations.py +0 -329
- heavyball-1.7.1.dist-info/METADATA +0 -939
- heavyball-1.7.1.dist-info/RECORD +0 -11
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.1.dist-info → heavyball-2.0.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,28 +1,29 @@
|
|
1
|
+
import collections
|
1
2
|
import contextlib
|
2
3
|
import functools
|
3
4
|
import gc
|
4
5
|
import inspect
|
5
6
|
import math
|
7
|
+
import pickle
|
6
8
|
import random
|
7
9
|
import re
|
8
10
|
import string
|
9
11
|
import warnings
|
10
|
-
from typing import Callable, List, Optional, Tuple, Union
|
12
|
+
from typing import Callable, List, Literal, Optional, Tuple, Union
|
11
13
|
|
12
14
|
import numpy as np
|
13
15
|
import torch
|
14
16
|
from torch import Tensor
|
15
|
-
from torch._dynamo import config
|
16
17
|
from torch._dynamo.exc import TorchDynamoException
|
17
18
|
from torch.backends import cudnn, opt_einsum
|
19
|
+
from torch.nn import functional as F
|
18
20
|
from torch.utils._pytree import tree_map
|
19
21
|
|
20
|
-
config.cache_size_limit = 2**16
|
21
|
-
|
22
22
|
compile_mode = "max-autotune-no-cudagraphs"
|
23
23
|
dynamic = False
|
24
24
|
compile_mode_recommended_to_none = None
|
25
|
-
zeroth_power_mode = "
|
25
|
+
zeroth_power_mode = "newtonschulz"
|
26
|
+
precise_zeroth_power_mode = "qr" # or svd
|
26
27
|
tiny_bf16 = torch.finfo(torch.bfloat16).tiny
|
27
28
|
_cudnn_double_backward_pattern = re.compile(
|
28
29
|
r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
|
@@ -50,7 +51,7 @@ def decorator(func):
|
|
50
51
|
return _fn
|
51
52
|
|
52
53
|
|
53
|
-
def decorator_knowngood(func: Callable):
|
54
|
+
def decorator_knowngood(func: Callable, fullgraph: bool = True):
|
54
55
|
compiled = None
|
55
56
|
|
56
57
|
@functools.wraps(func)
|
@@ -59,7 +60,7 @@ def decorator_knowngood(func: Callable):
|
|
59
60
|
return func(*args, **kwargs)
|
60
61
|
nonlocal compiled
|
61
62
|
if compiled is None:
|
62
|
-
compiled = torch.compile(fullgraph=
|
63
|
+
compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
|
63
64
|
return compiled(*args, **kwargs)
|
64
65
|
|
65
66
|
return _fn
|
@@ -68,6 +69,16 @@ def decorator_knowngood(func: Callable):
|
|
68
69
|
einsum_base = string.ascii_lowercase
|
69
70
|
|
70
71
|
|
72
|
+
@decorator_knowngood
|
73
|
+
def compiled_einsum(expr, *args):
|
74
|
+
"""
|
75
|
+
this is necessary to avoid the slowdown introduced by uncompiled einsum
|
76
|
+
uncompiled einsum is twice as slow if we add three 1-sized dimensions
|
77
|
+
for more, see https://gist.github.com/ClashLuke/a9530f1b9ba4e525369e2dba48528957
|
78
|
+
"""
|
79
|
+
return torch.einsum(expr, *args)
|
80
|
+
|
81
|
+
|
71
82
|
@decorator_knowngood
|
72
83
|
def _compilable_schedule_free_(
|
73
84
|
p: List[Tensor],
|
@@ -122,6 +133,47 @@ def schedule_free_(
|
|
122
133
|
return weight_sum
|
123
134
|
|
124
135
|
|
136
|
+
@decorator_knowngood
|
137
|
+
def _compilable_msam(
|
138
|
+
lr: Tensor,
|
139
|
+
beta1: Tensor,
|
140
|
+
param: List[Tensor],
|
141
|
+
z: List[Tensor],
|
142
|
+
update: List[Tensor],
|
143
|
+
grad: List[Tensor],
|
144
|
+
exp_avg: List[Tensor],
|
145
|
+
caution: bool,
|
146
|
+
decay: Tensor,
|
147
|
+
sam_step_size: Tensor,
|
148
|
+
):
|
149
|
+
exp_avg32 = _lerp(exp_avg, update, beta1)
|
150
|
+
for u_, g_, z_, p_ in zip(exp_avg32, grad, z, param):
|
151
|
+
u_ = u_.view_as(z_)
|
152
|
+
z32_ = promote(z_)
|
153
|
+
if caution:
|
154
|
+
u_ = _compilable_cautioning(promote(g_), u_)
|
155
|
+
z32_ = z32_ * (1 - decay * lr) + u_ * -lr
|
156
|
+
copy_stochastic_(z_, z32_)
|
157
|
+
copy_stochastic_(p_, z32_ + u_ / u_.norm().clamp(min=1e-8) * -sam_step_size)
|
158
|
+
|
159
|
+
|
160
|
+
def msam_(
|
161
|
+
lr: float,
|
162
|
+
beta1: float,
|
163
|
+
param: List[Tensor],
|
164
|
+
z: List[Tensor],
|
165
|
+
update: List[Tensor],
|
166
|
+
grad: List[Tensor],
|
167
|
+
exp_avg: List[Tensor],
|
168
|
+
caution: bool,
|
169
|
+
weight_decay: float,
|
170
|
+
sam_step_size: float,
|
171
|
+
):
|
172
|
+
param, z, update, grad, exp_avg = list_guard(param, z, update, grad, exp_avg)
|
173
|
+
lr, beta1, weight_decay, sam_step_size = scalar_guard(lr, beta1, weight_decay, sam_step_size, exp_avg[0])
|
174
|
+
_compilable_msam(lr, beta1, param, z, update, grad, exp_avg, caution, weight_decay, sam_step_size)
|
175
|
+
|
176
|
+
|
125
177
|
def append_or_extend(base, new):
|
126
178
|
if isinstance(new, list):
|
127
179
|
base.extend(new)
|
@@ -161,7 +213,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
161
213
|
new_shape = [grad.shape[0], *new_shape[::-1]]
|
162
214
|
new_grad = grad.reshape(new_shape)
|
163
215
|
if not split:
|
164
|
-
return new_grad
|
216
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
165
217
|
|
166
218
|
grads = [new_grad]
|
167
219
|
for i, sh in reversed(list(enumerate(new_shape[:]))):
|
@@ -172,7 +224,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
|
|
172
224
|
continue
|
173
225
|
grads = [a for g in grads for a in g.split(max_precond_dim, dim=i)]
|
174
226
|
if len(grads) == 1:
|
175
|
-
return new_grad
|
227
|
+
return new_grad.to(memory_format=torch.contiguous_format).contiguous()
|
176
228
|
new_grads = []
|
177
229
|
for g in grads:
|
178
230
|
append_or_extend(new_grads, dim_merger(g, max_precond_dim, split))
|
@@ -189,14 +241,14 @@ def eps_sqrt(item, eps):
|
|
189
241
|
|
190
242
|
@decorator_knowngood
|
191
243
|
def _compilable_exp_avg_sq_(
|
192
|
-
state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[
|
244
|
+
state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: None | List[None | Tensor]
|
193
245
|
):
|
194
246
|
g32 = promote(grad)
|
195
247
|
s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
|
196
248
|
|
197
249
|
denom = [eps_sqrt(d, eps) for d in s32]
|
198
250
|
|
199
|
-
if out[0] is None:
|
251
|
+
if out is None or out[0] is None:
|
200
252
|
return denom
|
201
253
|
|
202
254
|
copy_stochastic_list_(out, denom)
|
@@ -265,8 +317,8 @@ def adaptive_gradient_clipping_(
|
|
265
317
|
def is_compiling():
|
266
318
|
try:
|
267
319
|
return torch.compiler.is_compiling()
|
268
|
-
except TorchDynamoException:
|
269
|
-
return
|
320
|
+
except (TorchDynamoException, AttributeError):
|
321
|
+
return False
|
270
322
|
|
271
323
|
|
272
324
|
def set_(dst: Tensor, src: Tensor):
|
@@ -279,16 +331,29 @@ def clean():
|
|
279
331
|
|
280
332
|
|
281
333
|
def _ignore_warning(msg):
|
282
|
-
warnings.filterwarnings("ignore", f".*{msg}.*")
|
334
|
+
warnings.filterwarnings("ignore", f".*{re.escape(msg)}.*")
|
283
335
|
|
284
336
|
|
285
|
-
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
337
|
+
def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto-hq"):
|
338
|
+
import opt_einsum as _opt_einsum
|
339
|
+
|
286
340
|
cudnn.benchmark = True
|
287
341
|
cudnn.deterministic = False
|
288
342
|
cudnn.benchmark_limit = benchmark_limit
|
289
343
|
torch.use_deterministic_algorithms(False)
|
290
344
|
torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
|
291
|
-
opt_einsum.set_flags(True
|
345
|
+
opt_einsum.set_flags(True)
|
346
|
+
if einsum_strategy == "heavyball":
|
347
|
+
opt_einsum.strategy = "auto-hq"
|
348
|
+
choices = _opt_einsum.paths._AUTO_HQ_CHOICES
|
349
|
+
for max_val, fn in ((20, _opt_einsum.paths.dynamic_programming), (64, 512), (128, 256)):
|
350
|
+
if isinstance(fn, int):
|
351
|
+
fn = functools.partial(_opt_einsum.path_random.random_greedy, max_repeats=fn)
|
352
|
+
for i in range(max(choices.keys()), max_val):
|
353
|
+
if i not in choices:
|
354
|
+
choices[i] = fn
|
355
|
+
else:
|
356
|
+
opt_einsum.strategy = einsum_strategy
|
292
357
|
|
293
358
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
294
359
|
_ignore_warning(
|
@@ -297,32 +362,39 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
|
|
297
362
|
_ignore_warning(
|
298
363
|
"We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
|
299
364
|
)
|
365
|
+
_ignore_warning(
|
366
|
+
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead."
|
367
|
+
)
|
300
368
|
|
301
369
|
|
302
|
-
@
|
370
|
+
@decorator_knowngood
|
303
371
|
def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
304
|
-
assert
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
if G.
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
372
|
+
assert (
|
373
|
+
G.ndim >= 2
|
374
|
+
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
375
|
+
assert steps == 5
|
376
|
+
X = G if G.dtype == torch.float64 else stochastic_round_(G)
|
377
|
+
if G.size(-2) > G.size(-1):
|
378
|
+
X = X.mT
|
379
|
+
|
380
|
+
stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps) # ensure top singular value <= 1
|
381
|
+
# Perform the NS iterations
|
382
|
+
for a, b, c in [
|
383
|
+
(4.0848, -6.8946, 2.9270),
|
384
|
+
(3.9505, -6.3029, 2.6377),
|
385
|
+
(3.7418, -5.5913, 2.3037),
|
386
|
+
(2.8769, -3.1427, 1.2046),
|
387
|
+
(2.8366, -3.0525, 1.2012),
|
388
|
+
]:
|
389
|
+
A = X @ X.mT
|
390
|
+
B = (
|
391
|
+
b * A + c * A @ A
|
392
|
+
) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
313
393
|
X = a * X + B @ X
|
314
|
-
if G.size(0) > G.size(1):
|
315
|
-
X = X.T
|
316
|
-
return X.to(G.dtype)
|
317
|
-
|
318
394
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
if zeroth_power_mode == "svd":
|
323
|
-
u, _s, v = torch.linalg.svd(x)
|
324
|
-
return u @ v.T
|
325
|
-
raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
|
395
|
+
if G.size(-2) > G.size(-1):
|
396
|
+
X = X.mT
|
397
|
+
return X.to(G.dtype)
|
326
398
|
|
327
399
|
|
328
400
|
@decorator_knowngood
|
@@ -377,7 +449,7 @@ def _compilable_grafting(magnitude, direction):
|
|
377
449
|
|
378
450
|
|
379
451
|
@decorator_knowngood
|
380
|
-
def
|
452
|
+
def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
|
381
453
|
if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
|
382
454
|
y = zeropower_via_newtonschulz5(x, 5)
|
383
455
|
elif mode == "qr":
|
@@ -395,9 +467,16 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
|
|
395
467
|
y = _compilable_grafting(x, y)
|
396
468
|
else:
|
397
469
|
raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
|
470
|
+
if out is None:
|
471
|
+
return y
|
472
|
+
|
398
473
|
set_(out, y)
|
399
474
|
|
400
475
|
|
476
|
+
def inplace_orthogonal_(x: Tensor, mode: str | None = None, out: Tensor | None = None, scale_mode: str = "none"):
|
477
|
+
return _compilable_orthogonal_(x, mode or zeroth_power_mode, out, scale_mode)
|
478
|
+
|
479
|
+
|
401
480
|
@decorator_knowngood
|
402
481
|
def _compilable_scatter_set(target, source, index):
|
403
482
|
target[:] = source.contiguous()[index].reshape_as(target)
|
@@ -413,6 +492,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
413
492
|
:param Q: List of current eigenbases (updated in-place to Q_new).
|
414
493
|
:param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
|
415
494
|
"""
|
495
|
+
if exp_avg.dim() == 0: # preconditioning doesn't make sense here
|
496
|
+
Q.clear()
|
497
|
+
return
|
498
|
+
|
416
499
|
if isinstance(Q, list) and not Q:
|
417
500
|
return
|
418
501
|
|
@@ -430,10 +513,10 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
430
513
|
q_old = promote(q.data)
|
431
514
|
|
432
515
|
tmp = m @ q_old
|
433
|
-
est_eig =
|
516
|
+
est_eig = compiled_einsum("ij,ij->j", q_old, tmp)
|
434
517
|
sort_idx = torch.argsort(est_eig, descending=True)
|
435
518
|
|
436
|
-
tmp[:, sort_idx]
|
519
|
+
tmp[:, sort_idx] = inplace_orthogonal_(tmp[:, sort_idx], precise_zeroth_power_mode)
|
437
520
|
new_qs.append(tmp)
|
438
521
|
|
439
522
|
if exp_avg is None:
|
@@ -453,7 +536,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
453
536
|
out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
454
537
|
|
455
538
|
subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
|
456
|
-
exp_avg_new =
|
539
|
+
exp_avg_new = compiled_einsum(
|
457
540
|
subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
|
458
541
|
)
|
459
542
|
copy_stochastic_(exp_avg, exp_avg_new)
|
@@ -487,10 +570,16 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
|
|
487
570
|
except torch.OutOfMemoryError:
|
488
571
|
if m.device.type == "cpu":
|
489
572
|
raise
|
490
|
-
|
573
|
+
if torch.cuda.is_available():
|
574
|
+
torch.cuda.synchronize(m.device)
|
575
|
+
clean()
|
576
|
+
m = m.cpu()
|
577
|
+
except RuntimeError as e:
|
578
|
+
if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
|
579
|
+
torch.cuda.synchronize(m.device)
|
580
|
+
clean()
|
491
581
|
m = m.cpu()
|
492
|
-
|
493
|
-
if m.dtype != torch.double:
|
582
|
+
elif m.dtype != torch.double:
|
494
583
|
m = m.double()
|
495
584
|
elif eps < max_eps:
|
496
585
|
eps = eps ** (2 / 3)
|
@@ -568,6 +657,19 @@ def scalar_guard(*args):
|
|
568
657
|
return out
|
569
658
|
|
570
659
|
|
660
|
+
def broadcastable_list_guard(*xs):
|
661
|
+
xs = list_guard(*xs)
|
662
|
+
for x in xs:
|
663
|
+
if isinstance(x[0], Tensor):
|
664
|
+
ref = x[0]
|
665
|
+
break
|
666
|
+
else:
|
667
|
+
raise ValueError("No tensor-valued input given")
|
668
|
+
xs = [x if isinstance(x[0], Tensor) else list_guard(scalar_guard(*x, ref)) for x in xs]
|
669
|
+
max_len = max(len(x) for x in xs)
|
670
|
+
return [x if len(x) > 1 else x * max_len for x in xs]
|
671
|
+
|
672
|
+
|
571
673
|
@decorator_knowngood
|
572
674
|
def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
|
573
675
|
for x_, y_ in zip(x, y):
|
@@ -576,8 +678,8 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
|
|
576
678
|
copy_stochastic_(x_, x32 + y32 * alpha)
|
577
679
|
|
578
680
|
|
579
|
-
def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
|
580
|
-
x, y =
|
681
|
+
def stochastic_add_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1):
|
682
|
+
x, y = broadcastable_list_guard(x, y)
|
581
683
|
alpha = scalar_guard(alpha, x[0])
|
582
684
|
_compilable_stochastic_add_(x, y, alpha)
|
583
685
|
|
@@ -590,8 +692,10 @@ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha:
|
|
590
692
|
copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
|
591
693
|
|
592
694
|
|
593
|
-
def stochastic_add_divide_(
|
594
|
-
x, y =
|
695
|
+
def stochastic_add_divide_(
|
696
|
+
x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, alpha: Union[float, int, Tensor] = 1, divisor: float = 1
|
697
|
+
):
|
698
|
+
x, y = broadcastable_list_guard(x, y)
|
595
699
|
alpha, divisor = scalar_guard(alpha, divisor, x[0])
|
596
700
|
_compilable_stochastic_add_divide_(x, y, alpha, divisor)
|
597
701
|
|
@@ -604,8 +708,8 @@ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
|
604
708
|
copy_stochastic_(x_, x32 * y32)
|
605
709
|
|
606
710
|
|
607
|
-
def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
|
608
|
-
x, y =
|
711
|
+
def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
|
712
|
+
x, y = broadcastable_list_guard(x, y)
|
609
713
|
_compilable_stochastic_multiply_(x, y)
|
610
714
|
|
611
715
|
|
@@ -624,7 +728,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
624
728
|
b = einsum_base[idx]
|
625
729
|
g0 = einsum_base[: grad.dim()]
|
626
730
|
g1 = g0.replace(b, b.upper())
|
627
|
-
outer_product =
|
731
|
+
outer_product = compiled_einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
|
628
732
|
stochastic_lerp_(m, outer_product, 1 - beta)
|
629
733
|
|
630
734
|
|
@@ -706,7 +810,7 @@ def project(grad, Q, back: bool):
|
|
706
810
|
preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
|
707
811
|
if preconditioners:
|
708
812
|
out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
|
709
|
-
out =
|
813
|
+
out = compiled_einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
|
710
814
|
grad = out.to(grad.dtype)
|
711
815
|
return grad
|
712
816
|
|
@@ -714,24 +818,28 @@ def project(grad, Q, back: bool):
|
|
714
818
|
@contextlib.contextmanager
|
715
819
|
def patch_backward():
|
716
820
|
@contextlib.contextmanager
|
717
|
-
def
|
821
|
+
def patch_module(module):
|
718
822
|
original = module.backward
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
823
|
+
try:
|
824
|
+
signature = inspect.signature(original)
|
825
|
+
|
826
|
+
@functools.wraps(original)
|
827
|
+
def patched_backward(*args, **kwargs):
|
828
|
+
new_kwargs = signature.bind(*args)
|
829
|
+
new_kwargs.apply_defaults()
|
830
|
+
new_kwargs = new_kwargs.arguments
|
831
|
+
new_kwargs.update(kwargs)
|
832
|
+
new_kwargs["create_graph"] = True
|
833
|
+
return original(**new_kwargs)
|
834
|
+
|
835
|
+
module.backward = patched_backward
|
836
|
+
yield
|
837
|
+
finally:
|
838
|
+
module.backward = original
|
839
|
+
|
840
|
+
with contextlib.ExitStack() as stack:
|
841
|
+
stack.enter_context(patch_module(torch.Tensor))
|
842
|
+
stack.enter_context(patch_module(torch.autograd))
|
735
843
|
yield
|
736
844
|
|
737
845
|
|
@@ -743,6 +851,13 @@ class ExactHVPFailed(ValueError):
|
|
743
851
|
pass
|
744
852
|
|
745
853
|
|
854
|
+
use_default = object()
|
855
|
+
|
856
|
+
|
857
|
+
def _tensor_key(x: Tensor):
|
858
|
+
return x.data_ptr(), x.numel(), x.dtype, x.device
|
859
|
+
|
860
|
+
|
746
861
|
class StatefulOptimizer(torch.optim.Optimizer):
|
747
862
|
"""
|
748
863
|
finite_differences saves memory, but needs more compute. (Alternative is true HVP)
|
@@ -755,7 +870,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
755
870
|
compile_step: bool = False
|
756
871
|
hessian_approx: bool = False
|
757
872
|
precond_schedule: Union[Callable, float, None] = None
|
758
|
-
stochastic_schedule: bool =
|
873
|
+
stochastic_schedule: bool | Literal[use_default] = use_default
|
759
874
|
finite_differences: bool = False
|
760
875
|
fallback_to_finite_differences: bool = True
|
761
876
|
_fallback_enabled: bool = False
|
@@ -765,18 +880,62 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
765
880
|
super().__init__(params, {**defaults, "foreach": foreach})
|
766
881
|
self.use_ema = use_ema
|
767
882
|
self.mapping = {}
|
768
|
-
self.
|
769
|
-
|
883
|
+
self.mapping_inverse = {}
|
884
|
+
|
885
|
+
if self.stochastic_schedule is use_default:
|
886
|
+
stochastic_schedule = None
|
887
|
+
for group in self.param_groups:
|
888
|
+
new = group.get("stochastic_schedule", stochastic_schedule)
|
889
|
+
if stochastic_schedule is not None and new != stochastic_schedule:
|
890
|
+
raise ValueError("All parameter groups must have the same stochastic_schedule.")
|
891
|
+
stochastic_schedule = new
|
892
|
+
self.stochastic_schedule = stochastic_schedule
|
893
|
+
|
894
|
+
self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
|
895
|
+
self.precond_rng = random.Random(0x12312)
|
770
896
|
self._is_preconditioning = None
|
771
897
|
|
772
898
|
if self.hessian_approx and self.compile_step:
|
773
899
|
raise ValueError("Hessian approximation can't be used with compile_step.")
|
774
900
|
|
901
|
+
self.register_state_dict_post_hook(StatefulOptimizer._store_stats)
|
902
|
+
self.register_load_state_dict_pre_hook(StatefulOptimizer._load_stats)
|
903
|
+
|
904
|
+
def _store_stats(self, state_dict: dict[str, any]):
|
905
|
+
state_dict["heavyball"] = {
|
906
|
+
"inner_group": self.inner_group,
|
907
|
+
"precond_rng": pickle.dumps(self.precond_rng),
|
908
|
+
"use_ema": self.use_ema,
|
909
|
+
"ema_decay": self.ema_decay,
|
910
|
+
"compile_step": self.compile_step,
|
911
|
+
"hessian_approx": self.hessian_approx,
|
912
|
+
"precond_schedule": pickle.dumps(self.precond_schedule),
|
913
|
+
"stochastic_schedule": self.stochastic_schedule,
|
914
|
+
"fallback_to_finite_differences": self.fallback_to_finite_differences,
|
915
|
+
"_fallback_enabled": self._fallback_enabled,
|
916
|
+
"hvp_interval": self.hvp_interval,
|
917
|
+
}
|
918
|
+
|
919
|
+
def _load_stats(self, state_dict):
|
920
|
+
sd = state_dict.pop("heavyball", {})
|
921
|
+
for k, v in sd.items():
|
922
|
+
if k in ("precond_rng", "precond_schedule"):
|
923
|
+
v = pickle.loads(v)
|
924
|
+
setattr(self, k, v)
|
925
|
+
|
775
926
|
def get_groups(self, group):
|
776
927
|
return [group]
|
777
928
|
|
778
|
-
|
779
|
-
|
929
|
+
@functools.lru_cache(maxsize=None)
|
930
|
+
def state_(self, arg: Tensor, fail: bool = True):
|
931
|
+
if not fail and arg not in self.mapping:
|
932
|
+
return {}
|
933
|
+
if _tensor_key(arg) not in self.mapping_inverse:
|
934
|
+
self._init_mapping()
|
935
|
+
state_param, index = self.mapping_inverse[_tensor_key(arg)]
|
936
|
+
if state_param not in self.state:
|
937
|
+
self.state[state_param] = collections.defaultdict(dict)
|
938
|
+
return self.state[state_param][index]
|
780
939
|
|
781
940
|
def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
|
782
941
|
for p, g in zip(p_list, g_list):
|
@@ -786,6 +945,18 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
786
945
|
old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
|
787
946
|
mars_correction(g_list, old_gs, mars_gamma, beta)
|
788
947
|
|
948
|
+
def _init_mapping(self, group: dict | None = None):
|
949
|
+
if group is None:
|
950
|
+
for group in self.param_groups:
|
951
|
+
self._init_mapping(group)
|
952
|
+
return
|
953
|
+
|
954
|
+
for p in group["params"]:
|
955
|
+
if p not in self.mapping:
|
956
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
957
|
+
for i, pv in enumerate(p_views):
|
958
|
+
self.mapping_inverse[_tensor_key(pv)] = (p, i)
|
959
|
+
|
789
960
|
def split_p_and_g_in_group(
|
790
961
|
self,
|
791
962
|
group: dict,
|
@@ -805,10 +976,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
805
976
|
yield p, grad
|
806
977
|
continue
|
807
978
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
self.mapping[p] = p_views = merge_group(group, p)
|
979
|
+
self.mapping[p] = p_views = merge_group(group, p)
|
980
|
+
for i, pv in enumerate(p_views):
|
981
|
+
self.mapping_inverse[_tensor_key(pv)] = (p, i)
|
812
982
|
|
813
983
|
vector = getattr(p, "vector", None)
|
814
984
|
hessian_vector = getattr(p, "hessian_vector", None)
|
@@ -957,8 +1127,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
957
1127
|
raise ValueError("Hessian approximation requires a closure.")
|
958
1128
|
return None
|
959
1129
|
|
960
|
-
step = self.
|
961
|
-
if not hessian_approx or step % self.hvp_interval == 0:
|
1130
|
+
step = self.inner_group["total_hvp_steps"] = self.inner_group.get("total_hvp_steps", 0) + 1
|
1131
|
+
if not hessian_approx or (step - 1) % self.hvp_interval == 0: # hvp in 0th step for better precond init
|
962
1132
|
with torch.enable_grad():
|
963
1133
|
loss = closure()
|
964
1134
|
return loss
|
@@ -997,12 +1167,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
997
1167
|
if self.precond_schedule is None:
|
998
1168
|
self._is_preconditioning = False
|
999
1169
|
else:
|
1000
|
-
self._is_preconditioning = psgd_should_update(self.
|
1170
|
+
self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
|
1001
1171
|
loss = self._handle_closure(closure)
|
1002
1172
|
|
1003
1173
|
# we assume that parameters are constant and that there are no excessive recompiles
|
1004
1174
|
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
|
1005
1175
|
for group in self.param_groups:
|
1176
|
+
if "param_count" not in group:
|
1177
|
+
group["param_count"] = sum(p.numel() for p in group["params"])
|
1006
1178
|
group["is_preconditioning"] = self._is_preconditioning
|
1007
1179
|
self._step(group)
|
1008
1180
|
if self.use_ema:
|
@@ -1105,7 +1277,7 @@ def fused_adam_(
|
|
1105
1277
|
caution: bool,
|
1106
1278
|
):
|
1107
1279
|
y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
|
1108
|
-
beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
|
1280
|
+
beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, y[0])
|
1109
1281
|
_fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
|
1110
1282
|
|
1111
1283
|
|
@@ -1184,7 +1356,7 @@ def fused_laprop_(
|
|
1184
1356
|
eps: float = 1e-8,
|
1185
1357
|
):
|
1186
1358
|
exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
|
1187
|
-
beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
|
1359
|
+
beta1, beta2, step, lr, eps, decay = scalar_guard(beta1, beta2, step, lr, eps, decay, exp_avg[0])
|
1188
1360
|
_fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
|
1189
1361
|
|
1190
1362
|
|
@@ -1203,7 +1375,7 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
|
|
1203
1375
|
|
1204
1376
|
def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
|
1205
1377
|
exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
|
1206
|
-
beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
|
1378
|
+
beta1, beta2, step, lr, decay = scalar_guard(beta1, beta2, step, lr, decay, exp_avg[0])
|
1207
1379
|
_fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
|
1208
1380
|
|
1209
1381
|
|
@@ -1233,11 +1405,15 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
|
|
1233
1405
|
|
1234
1406
|
|
1235
1407
|
@decorator_knowngood
|
1236
|
-
def stochastic_round_(ref: Tensor, source: Tensor):
|
1237
|
-
if source
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1408
|
+
def stochastic_round_(ref: Tensor, source: Tensor | None = None):
|
1409
|
+
if source is not None:
|
1410
|
+
if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
|
1411
|
+
return source
|
1412
|
+
if ref.dtype != torch.bfloat16:
|
1413
|
+
return source.to(ref.dtype)
|
1414
|
+
else:
|
1415
|
+
source = ref
|
1416
|
+
source = source.float()
|
1241
1417
|
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
1242
1418
|
result.add_(source.view(dtype=torch.int32))
|
1243
1419
|
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
@@ -1306,74 +1482,115 @@ def stable_exp(x: Tensor):
|
|
1306
1482
|
return torch.where(x > 0, 1 / (-x).exp(), x.exp())
|
1307
1483
|
|
1308
1484
|
|
1485
|
+
def _lse_mean(x: Tensor, pow: float, eps: float) -> Tensor:
|
1486
|
+
# ln(mean(x ** pow) ** (1 / pow / 2))
|
1487
|
+
normalization = math.log(x.numel())
|
1488
|
+
x = x.double()
|
1489
|
+
x = x.abs()
|
1490
|
+
x = x.clamp(min=eps)
|
1491
|
+
x = x.log()
|
1492
|
+
x = x * pow
|
1493
|
+
x = x.flatten()
|
1494
|
+
x = x.logsumexp(dim=0) # log(sum(exp( log(x) * P ) - more stable than sum(x ** P)
|
1495
|
+
x = x - normalization # sum -> mean (divide by x.numel() in log space)
|
1496
|
+
return x / pow / 2
|
1497
|
+
|
1498
|
+
|
1309
1499
|
@decorator_knowngood
|
1310
1500
|
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
|
1311
1501
|
# 1 / (mean(x ** pow) ** (1 / pow / 2))
|
1312
|
-
|
1313
|
-
log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
|
1314
|
-
return stable_exp(-log_mean_x_pow / pow / 2)
|
1502
|
+
return stable_exp(-_lse_mean(x, pow, eps))
|
1315
1503
|
|
1316
1504
|
|
1317
1505
|
@decorator_knowngood
|
1318
1506
|
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
|
1319
1507
|
# mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
|
1320
|
-
|
1321
|
-
log_y = y.double().abs().clamp(min=eps).log()
|
1322
|
-
|
1323
|
-
x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
|
1324
|
-
x_normed = x_normed / pow0 / 2
|
1508
|
+
return stable_exp(_lse_mean(x, pow0, eps) - _lse_mean(y, pow1, eps))
|
1325
1509
|
|
1326
|
-
y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
|
1327
|
-
y_normed = y_normed / pow1 / 2
|
1328
1510
|
|
1329
|
-
|
1511
|
+
class PrecondInitError(ValueError):
|
1512
|
+
pass
|
1330
1513
|
|
1331
1514
|
|
1332
|
-
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float =
|
1515
|
+
def precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector, scale_max: float = 100):
|
1333
1516
|
automatic_scale = True
|
1334
1517
|
manual_hint = " Set it manually using `precond_init_scale=0.1`"
|
1518
|
+
scale_scale = 1 if scale_scale is None else scale_scale
|
1519
|
+
|
1335
1520
|
if scale is not None:
|
1336
1521
|
automatic_scale = False
|
1337
1522
|
warn_once(
|
1338
1523
|
"It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
|
1339
1524
|
)
|
1340
|
-
if scale_scale
|
1525
|
+
if scale_scale != 1:
|
1341
1526
|
warn_once(
|
1342
|
-
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly
|
1527
|
+
"precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly fuse it."
|
1528
|
+
)
|
1529
|
+
if scale_power is not None:
|
1530
|
+
warn_once(
|
1531
|
+
"precond_init_scale_power is used to compute precond_init_scale ** precond_init_scale_power. With a fixed precond_init_scale, you should explicitly fuse it."
|
1343
1532
|
)
|
1344
1533
|
elif hessian_vector is None:
|
1345
1534
|
scale = mean_root(grad, 4) * scale_scale
|
1346
1535
|
else:
|
1347
1536
|
scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
|
1537
|
+
|
1538
|
+
if automatic_scale:
|
1539
|
+
scale_power = 0.5 if scale_power is None else scale_power
|
1540
|
+
scale = scale**scale_power
|
1541
|
+
|
1348
1542
|
if isinstance(scale, torch.Tensor):
|
1349
1543
|
scale = scale.item() # slow, but necessary
|
1544
|
+
|
1350
1545
|
if np.isfinite(scale):
|
1351
|
-
if scale > scale_max
|
1546
|
+
if scale > scale_max: # fallthrough to later checks
|
1352
1547
|
warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
|
1353
|
-
|
1548
|
+
else:
|
1549
|
+
return scale
|
1550
|
+
|
1354
1551
|
if not automatic_scale:
|
1355
|
-
raise
|
1552
|
+
raise PrecondInitError("The manually set precond_init_scale is not finite")
|
1356
1553
|
|
1357
1554
|
for x in (grad, hessian_vector, vector):
|
1358
1555
|
if x is None:
|
1359
1556
|
continue
|
1360
|
-
if torch.allclose(x, torch.zeros_like(x))
|
1361
|
-
raise
|
1557
|
+
if torch.allclose(x, torch.zeros_like(x)):
|
1558
|
+
raise PrecondInitError(
|
1559
|
+
f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}"
|
1560
|
+
)
|
1362
1561
|
if not torch.isfinite(x).all().item():
|
1363
|
-
raise
|
1364
|
-
|
1562
|
+
raise PrecondInitError("Grad or HVP is not finite")
|
1563
|
+
|
1564
|
+
if np.isfinite(scale):
|
1565
|
+
return scale
|
1566
|
+
|
1567
|
+
raise PrecondInitError(f"Computed precond_init_scale is not finite.{manual_hint}")
|
1365
1568
|
|
1366
1569
|
|
1367
|
-
def init_lra(
    grad, param_count, scale, scale_scale, scale_power, rank, hessian_vector, vector, dtype=None, eps: float = 10
):
    """
    Initialise the low-rank-approximation (LRA) preconditioner factors.

    Returns ``(U, V, d)``:
      * ``U`` and ``V``: standard-normal factors of shape ``(*grad.shape, rank)``, scaled by
        ``(param_count * (rank + eps)) ** -0.5``;
      * ``d``: a tensor shaped like ``grad``, filled with the initial scale from ``precond_init_scale``.

    The ``eps`` offset follows @lixilinx: "+10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when
    rank_of_approximation=1", see
    https://github.com/lixilinx/psgd_torch/blob/590cd3f125552998ed20028be096652540e2a200/preconditioned_stochastic_gradient_descent.py#L829C11-L829C14
    """
    init_scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
    factor_shape = (*grad.shape, rank)
    factor_norm = (param_count * (rank + eps)) ** -0.5

    # Draw U first, then V, so the RNG consumption order is unchanged.
    factors = []
    for _ in range(2):
        factors.append(torch.randn(factor_shape, dtype=dtype, device=grad.device) * factor_norm)
    U, V = factors

    d = torch.full_like(grad, init_scale, dtype=dtype, device=grad.device)
    return U, V, d
|
1373
1581
|
|
1374
1582
|
|
1375
1583
|
def init_Q_exprs(
|
1376
|
-
grad,
|
1584
|
+
grad,
|
1585
|
+
scale,
|
1586
|
+
scale_scale,
|
1587
|
+
scale_power,
|
1588
|
+
max_size,
|
1589
|
+
min_ndim_triangular,
|
1590
|
+
memory_save_mode,
|
1591
|
+
hessian_vector,
|
1592
|
+
vector,
|
1593
|
+
dtype=None,
|
1377
1594
|
):
|
1378
1595
|
"""
|
1379
1596
|
For a scalar or tensor `grad`, we initialize its preconditioner Q and
|
@@ -1382,21 +1599,13 @@ def init_Q_exprs(
|
|
1382
1599
|
precond init scale computation from
|
1383
1600
|
https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
|
1384
1601
|
"""
|
1385
|
-
scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
|
1386
|
-
letters = string.ascii_lowercase + string.ascii_uppercase
|
1602
|
+
scale = precond_init_scale(scale, scale_scale, scale_power, grad, hessian_vector, vector)
|
1387
1603
|
dtype = dtype if dtype is not None else grad.dtype
|
1388
1604
|
shape = grad.shape
|
1389
1605
|
|
1390
1606
|
if len(shape) == 0: # scalar
|
1391
1607
|
Q = [scale * torch.ones_like(grad, dtype=dtype)]
|
1392
|
-
|
1393
|
-
exprGs = [",->"]
|
1394
|
-
exprP = ",,->"
|
1395
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1396
|
-
|
1397
|
-
# Tensor
|
1398
|
-
if len(shape) > 13:
|
1399
|
-
raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
|
1608
|
+
return Q
|
1400
1609
|
|
1401
1610
|
scale = scale ** (1 / len(shape))
|
1402
1611
|
|
@@ -1409,6 +1618,9 @@ def init_Q_exprs(
|
|
1409
1618
|
sorted_shape = sorted(shape)
|
1410
1619
|
if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
|
1411
1620
|
dim_diag[_max_idx(shape)] = True
|
1621
|
+
elif memory_save_mode == "one_triu":
|
1622
|
+
shape_ranks = np.argsort(np.argsort(shape)) # ranks
|
1623
|
+
dim_diag = (shape_ranks != 0).tolist() # only triu the smallest
|
1412
1624
|
elif memory_save_mode == "all_diag":
|
1413
1625
|
dim_diag = [True for _ in shape]
|
1414
1626
|
else:
|
@@ -1418,66 +1630,90 @@ def init_Q_exprs(
|
|
1418
1630
|
)
|
1419
1631
|
|
1420
1632
|
Q = []
|
1421
|
-
piece1A, piece2A, piece3A = ([], "", "")
|
1422
|
-
exprGs = []
|
1423
|
-
piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
|
1424
1633
|
for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
|
1425
1634
|
if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
|
1426
1635
|
# use diagonal matrix as preconditioner for this dim
|
1427
1636
|
Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
|
1428
|
-
|
1429
|
-
piece1A.append(letters[i])
|
1430
|
-
piece2A = piece2A + letters[i]
|
1431
|
-
piece3A = piece3A + letters[i]
|
1432
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1433
|
-
subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
|
1434
|
-
exprGs.append(subscripts)
|
1435
|
-
piece1P.append(letters[i + 13])
|
1436
|
-
piece2P.append(letters[i + 13])
|
1437
|
-
piece3P = piece3P + letters[i + 13]
|
1438
|
-
piece4P = piece4P + letters[i + 13]
|
1439
1637
|
else:
|
1440
1638
|
# use triangular matrix as preconditioner for this dim
|
1441
1639
|
Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
|
1442
|
-
|
1443
|
-
piece2A = piece2A + letters[i + 13]
|
1444
|
-
piece3A = piece3A + letters[i]
|
1445
|
-
piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
|
1446
|
-
piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
|
1447
|
-
subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
|
1448
|
-
exprGs.append(subscripts)
|
1449
|
-
a, b, c = (letters[i], letters[i + 13], letters[i + 26])
|
1450
|
-
piece1P.append(a + b)
|
1451
|
-
piece2P.append(a + c)
|
1452
|
-
piece3P = piece3P + c
|
1453
|
-
piece4P = piece4P + b
|
1454
|
-
|
1455
|
-
exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
|
1456
|
-
exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
|
1457
|
-
return [Q, (exprA, tuple(exprGs), exprP)]
|
1640
|
+
return Q
|
1458
1641
|
|
1459
1642
|
|
1460
|
-
@decorator_knowngood
def psgd_balance_Q(Q):
    """
    Rescale the Kronecker factors in ``Q`` in place so they share the same infinity norm.

    Each factor is multiplied by ``exp(mean_log_norm - log_norm)``. This moves every factor's infinity norm to
    the geometric mean of all of them. The multipliers' product is 1, so the Kronecker product is unchanged.
    """
    log_norms = [promote(q.norm(float("inf"))).log() for q in Q]
    mean_log_norm = sum(log_norms) / len(Q)
    for factor, log_norm in zip(Q, log_norms):
        factor *= (mean_log_norm - log_norm).exp()
|
1466
1649
|
|
1467
1650
|
|
1468
|
-
@decorator_knowngood
def _lra_flatten_and_balance(U: List[Tensor], V: List[Tensor], d: List[Tensor]):
    """
    Balance the magnitudes of the low-rank factors U and V, then flatten U, V and d.

    Let ``ratio = (||U||^2 / ||V||^2) ** 0.25``, the square root of the ratio of their L2 norms. U is divided
    by ``ratio`` and V multiplied by it, in place via ``stochastic_multiply_``, so both end up with the same L2
    norm. If the ratio is non-finite or <= 1e-6, a factor of 1 is used instead.

    Returns ``multi_flatten((U, 1), (V, 1), (d, 0))``. U and V keep their trailing (rank) dimension.
    """
    u_sq_norm = sum(u.square().sum().double() for u in U)
    v_sq_norm = sum(v.square().sum().double() for v in V)
    ratio = (u_sq_norm / v_sq_norm) ** 0.25  # sqrt of L2 norms; sqrt, as it's 2 factors
    usable = torch.logical_and(torch.isfinite(ratio), ratio > 1e-6)
    ratio = torch.where(usable, ratio, 1)
    stochastic_multiply_(U, [1 / ratio] * len(U))
    stochastic_multiply_(V, [ratio] * len(V))
    return multi_flatten((U, 1), (V, 1), (d, 0))
|
1475
1660
|
|
1476
1661
|
|
1477
1662
|
@decorator
def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
    """
    Compute ``x + U @ (V.T @ x)``, i.e. ``(I + U V^T) x``, without materialising ``U V^T``.

    The einsum runs in the dtype selected by ``min_dtype`` over U, V and x. The correction term is cast back
    to ``x.dtype`` before being added.
    """
    compute_dtype = min_dtype([U, V, x])
    U_c, V_c, x_c = (t.to(compute_dtype) for t in (U, V, x))
    correction = compiled_einsum("br,gr,g->b", U_c, V_c, x_c)
    return x + correction.to(x.dtype)
|
1666
|
+
|
1667
|
+
|
1668
|
+
@decorator_knowngood
def _compilable_d_step(
    d: Tensor,
    d_orig: List[Tensor],
    invQtv: Tensor,
    vector: Tensor,
    inverse_precond_vector: Tensor,
    hessian_vector: Tensor,
    precond_hessian_vector: Tensor,
    eps: Tensor,
    step: Tensor,
    delayed: bool,
):
    """
    Update the diagonal part ``d`` of the LRA preconditioner.

    ``d_orig`` (the stored diagonal) is always updated via ``apply_flat_add``. When ``delayed`` is False, the
    flattened working copy ``d`` is also updated in place by the SAME amount. The current step then
    preconditions with the fresh diagonal.

    Step-size normalisation:
      1) Sketching
         1.1) multiply, square, etc. in high precision (to avoid numerical errors + doesn't increase cost)
         1.2) reduced-precision selection of largest element (halves memory traffic)
      2) Computation
         2.1) select relevant indices
         2.2) redo 1.1 in double precision for scalar values
         2.3) return high-precision normalized step-size
    Overall, this should REDUCE the cost of the operation compared to baseline (-> less memory traffic) while
    improving precision.
    """
    precond_hessian_vector = promote(precond_hessian_vector)
    hessian_vector = promote(hessian_vector)
    vector = promote(vector)
    inverse_precond_vector = promote(inverse_precond_vector)
    invQtv = promote(invQtv)
    inverse_precond_vector = invQtv - inverse_precond_vector

    nablaD = promote(d).square() * precond_hessian_vector * hessian_vector - vector * inverse_precond_vector

    a0 = promote(d) * precond_hessian_vector
    a1 = vector
    b0 = inverse_precond_vector / promote(d)
    b1 = hessian_vector

    # 1) sketch: locate the largest element in bf16; 2) recompute only that element in float64
    divisor = (a0.square() + a1.square()) * (b0.square() + b1.square())
    idx = divisor.bfloat16().flatten().argmax()
    a = a0.index_select(0, idx).double().square() + a1.index_select(0, idx).double().square()
    b = b0.index_select(0, idx).double().square() + b1.index_select(0, idx).double().square()
    divisor = (a * b).sqrt().clamp(min=eps)
    step = -step / divisor  # NOTE: negated here; both updates below ADD ``nablaD * step``

    # fused update(s): d_orig receives ``nablaD * step`` (apply_flat_add is used the same way as
    # stochastic_add_ elsewhere in this file, i.e. it adds ``alpha * update``)
    apply_flat_add(d_orig, nablaD, step)
    if not delayed:
        # The working copy must move in the same direction as d_orig. Subtracting here (as before) applied the
        # opposite update to the diagonal used for the current step.
        copy_stochastic_(d, promote(d) + nablaD * step)
|
1481
1717
|
|
1482
1718
|
|
1483
1719
|
def update_lra_precond_(
|
@@ -1489,13 +1725,14 @@ def update_lra_precond_(
|
|
1489
1725
|
eps: float,
|
1490
1726
|
step: float,
|
1491
1727
|
delayed: bool,
|
1728
|
+
precond_u: bool,
|
1492
1729
|
):
|
1493
1730
|
"""
|
1494
1731
|
Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
|
1495
1732
|
"""
|
1496
1733
|
U_orig, V_orig, d_orig = U, V, d
|
1497
1734
|
|
1498
|
-
U, V, d =
|
1735
|
+
U, V, d = _lra_flatten_and_balance(U, V, d)
|
1499
1736
|
|
1500
1737
|
dtype = min_dtype([U, V, vector, hessian_vector])
|
1501
1738
|
U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
|
@@ -1503,10 +1740,10 @@ def update_lra_precond_(
|
|
1503
1740
|
eps = scalar_guard(eps, vector)
|
1504
1741
|
|
1505
1742
|
Qh = low_rank_mm(U, V, d * hessian_vector)
|
1506
|
-
Ph =
|
1743
|
+
Ph = low_rank_mm(V, U, Qh)
|
1507
1744
|
rank = U.size(1)
|
1508
1745
|
|
1509
|
-
VtU =
|
1746
|
+
VtU = compiled_einsum("br,bn->rn", V, U) # (rank, rank)
|
1510
1747
|
I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
|
1511
1748
|
IpVtU = I + VtU
|
1512
1749
|
invQtv = vector / d
|
@@ -1524,47 +1761,39 @@ def update_lra_precond_(
|
|
1524
1761
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1525
1762
|
|
1526
1763
|
invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
|
1527
|
-
invPv =
|
1528
|
-
invPv = invPv / d
|
1529
|
-
|
1530
|
-
nablaD = Ph * hessian_vector - vector * invPv
|
1531
|
-
divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
|
1532
|
-
divisor = divisor.add(eps).sqrt().max()
|
1533
|
-
d_step = step / divisor
|
1764
|
+
invPv = U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
|
1534
1765
|
|
1535
|
-
|
1766
|
+
eps, step = scalar_guard(eps, step, vector)
|
1767
|
+
_compilable_d_step(d, d_orig, invQtv, vector, invPv, hessian_vector, Ph, eps, step, delayed)
|
1536
1768
|
|
1537
1769
|
a, b = Qh, invQtv
|
1538
1770
|
|
1539
|
-
precond_u = random.random() < 0.5 # update either U or V, not both at the same time
|
1540
1771
|
precond = V if precond_u else U
|
1541
|
-
atV =
|
1542
|
-
btV =
|
1543
|
-
atVVt =
|
1544
|
-
btVVt =
|
1545
|
-
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()
|
1772
|
+
atV = compiled_einsum("b,br->r", a, precond) # o == one
|
1773
|
+
btV = compiled_einsum("b,br->r", b, precond)
|
1774
|
+
atVVt = compiled_einsum("r,br->b", atV, precond)
|
1775
|
+
btVVt = compiled_einsum("r,br->b", btV, precond)
|
1776
|
+
precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm()).clamp(min=eps)
|
1546
1777
|
if precond_u:
|
1547
|
-
a =
|
1548
|
-
b =
|
1778
|
+
a = compiled_einsum("b,r,rg->bg", a, atV, IpVtU)
|
1779
|
+
b = compiled_einsum("b,r,rg->bg", b, btV, IpVtU)
|
1549
1780
|
else:
|
1550
|
-
a = a +
|
1551
|
-
b = b +
|
1552
|
-
a =
|
1553
|
-
b =
|
1781
|
+
a = a + compiled_einsum("br,r->b", V, atV)
|
1782
|
+
b = b + compiled_einsum("br,r->b", V, btV)
|
1783
|
+
a = compiled_einsum("b,r->br", a, atV)
|
1784
|
+
b = compiled_einsum("b,r->br", b, btV)
|
1554
1785
|
apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
|
1555
|
-
|
1556
1786
|
if not delayed:
|
1557
|
-
stochastic_add_([d], [d * nablaD], -d_step)
|
1558
1787
|
stochastic_add_([U if precond_u else V], [b - a], precond_step)
|
1559
1788
|
return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
|
1560
1789
|
|
1561
1790
|
|
1562
|
-
def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
    """
    Apply the LRA preconditioner to ``g``.

    Computes ``d * (I + V U^T) (I + U V^T) (d * g)`` with two ``low_rank_mm`` calls.
    As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
    """
    scaled = d * g
    inner = low_rank_mm(U, V, scaled)
    outer = low_rank_mm(V, U, inner)
    return d * outer
|
1568
1797
|
|
1569
1798
|
|
1570
1799
|
@decorator_knowngood
|
@@ -1575,16 +1804,42 @@ def dampen_grad(g: Tensor, damp: float = 2**-13):
|
|
1575
1804
|
|
1576
1805
|
|
1577
1806
|
@decorator_knowngood
def _compilable_lra_update_(
    params: List[Tensor],
    update: List[Tensor],
    U: Tensor,
    V: Tensor,
    d: Tensor,
    lr: Tensor,
    decay: Tensor,
    caution: bool,
    grads: List[Tensor],
):
    """
    Precondition the flattened ``update`` with the LRA factors and apply it to ``params`` one slice at a time.

    Each parameter receives the slice of the flat preconditioned update at its running offset, reshaped to the
    parameter's shape. ``update_param_`` applies it together with ``lr``, ``decay``, ``caution`` and the
    parameter's gradient.
    """
    flat = lra_precond(U, V, d, flatten(update)).flatten()
    offset = 0
    for param, grad in zip(params, grads):
        numel = param.numel()
        chunk = flat[offset : offset + numel].view_as(param)
        update_param_(param, chunk, lr, decay, caution, grad)
        offset += numel
|
1586
1825
|
|
1587
1826
|
|
1827
|
+
def apply_lra_update(
    params: List[Tensor],
    update: Tensor,
    U: Tensor,
    V: Tensor,
    d: Tensor,
    lr: float,
    decay: float,
    caution: bool,
    grads: List[Tensor],
):
    """
    Entry point for the LRA parameter update.

    ``params``/``grads`` go through ``list_guard``. ``lr``/``decay`` are wrapped with ``scalar_guard``,
    relative to ``params[0]``. The work is then delegated to ``_compilable_lra_update_``.
    """
    params, grads = list_guard(params, grads)
    lr_t, decay_t = scalar_guard(lr, decay, params[0])
    _compilable_lra_update_(params, update, U, V, d, lr_t, decay_t, caution, grads)
|
1841
|
+
|
1842
|
+
|
1588
1843
|
@decorator_knowngood
|
1589
1844
|
def apply_flat_update(params: List[Tensor], update: Tensor):
|
1590
1845
|
start = 0
|
@@ -1595,6 +1850,12 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
|
|
1595
1850
|
start += size
|
1596
1851
|
|
1597
1852
|
|
1853
|
+
@decorator_knowngood
def zero_(x: List[Tensor]):
    """Set every tensor in ``x`` to zero, in place."""
    for tensor in x:
        tensor.zero_()
|
1857
|
+
|
1858
|
+
|
1598
1859
|
@decorator_knowngood
|
1599
1860
|
def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
|
1600
1861
|
start = 0
|
@@ -1620,7 +1881,12 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
|
|
1620
1881
|
@decorator_knowngood
def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
    """
    Flatten every tensor in ``x`` and concatenate the results along dim 0.

    The trailing ``remaining`` dimensions (taken from ``x[0]``) are kept; all leading dimensions are merged.
    Tensors with zero elements are skipped.
    """
    trailing = x[0].shape[-remaining:] if remaining else []
    pieces = [t.reshape(-1, *trailing) for t in x if t.numel()]
    return torch.cat(pieces, 0)
|
1885
|
+
|
1886
|
+
|
1887
|
+
@decorator_knowngood
def multi_flatten(*xs: Tuple[List[Tensor], int]):
    """Apply ``flatten`` to each ``(tensors, remaining)`` pair; return the results as a list, in order."""
    results = []
    for tensors, remaining in xs:
        results.append(flatten(tensors, remaining))
    return results
|
1624
1890
|
|
1625
1891
|
|
1626
1892
|
@decorator_knowngood
|
@@ -1634,107 +1900,566 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
|
|
1634
1900
|
return flatten(vs), flatten(gs)
|
1635
1901
|
|
1636
1902
|
|
1637
|
-
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
    """
    Run ``compiled_einsum(expr, *args)`` with every operand cast to the common dtype chosen by ``min_dtype``.

    The result is cast back to the dtype of the LAST operand.
    """
    common = min_dtype(args)
    operands = [arg.to(common) for arg in args]
    result = compiled_einsum(expr, *operands)
    return result.to(args[-1].dtype)
|
1641
1906
|
|
1642
1907
|
|
1643
|
-
|
1644
|
-
|
1645
|
-
|
1646
|
-
|
1647
|
-
|
1648
|
-
A = casted_einsum(exprA, *Q, G)
|
1649
|
-
for i, q in enumerate(Q):
|
1908
|
+
@decorator_knowngood
def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
    """
    Divide ``conjB`` by every diagonal (ndim <= 1) factor in ``Qs``.

    A diagonal factor at position ``i`` is broadcast along axis ``i`` of ``conjB``. If ``conjB`` is 0-D, the
    division is applied directly. Triangular factors (ndim > 1) are not applied here. They are returned as
    ``(axis, promoted_factor)`` pairs for the caller to solve against.

    Returns ``(triangular_qs, conjB)``, with ``conjB`` promoted.
    """
    conjB = promote(conjB)
    triangular = []
    for axis, factor in enumerate(Qs):
        factor = promote(factor)
        if factor.dim() > 1:
            triangular.append((axis, factor))
            continue
        if conjB.ndim == 0:
            conjB = conjB / factor
            continue
        broadcast_shape = [1] * conjB.ndim
        broadcast_shape[axis] = -1
        conjB = conjB / factor.view(broadcast_shape)
    return triangular, conjB
|
1924
|
+
|
1925
|
+
|
1926
|
+
@decorator_knowngood
def _reshape_conjB(solved: Tensor, transposed_shape: List[int], original_shape: List[int], last_dim: int, new_dim: int):
    """
    Undo the previous axis swap on ``solved``, then move axis ``new_dim`` to the last position.

    Steps:
      1. reshape to ``transposed_shape``;
      2. swap axis ``last_dim`` back with the last axis;
      3. reshape to ``original_shape``;
      4. swap axis ``new_dim`` with the last axis.
    Returns the contiguous result and its shape, which the caller passes back as ``transposed_shape`` next time.
    """
    restored = solved.reshape(transposed_shape).transpose(-1, last_dim).reshape(original_shape)
    moved = restored.transpose(-1, new_dim)
    return moved.contiguous(), moved.shape
|
1933
|
+
|
1934
|
+
|
1935
|
+
def ndim_tuple(Q: list[Tensor]) -> tuple:
    """Return the number of dimensions of each tensor in ``Q`` as a (hashable) tuple, e.g. for cache keys."""
    dims = [q.dim() for q in Q]
    return tuple(dims)
|
1937
|
+
|
1938
|
+
|
1939
|
+
def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None):  # conjB ("V", "vector") == randn during hvp/whitening
    """
    Compute PSGD's ``A`` (the preconditioner factors applied to ``G``) and ``conjB`` (their inverses applied to V).

    Args:
        G: gradient tensor.
        Q: list of preconditioner factors (1-D/0-D diagonal or 2-D triangular).
        conjB: probe vector V; if None, ``torch.randn_like(G)`` is drawn.

    Returns:
        ``(A, conjB)``:
        * ``A = casted_einsum(exprA, *Q, G)``, with the expression from ``cached_precond_grad_expr``.
        * ``conjB``: V divided by every diagonal factor (``_psgd_calc_scalars_``). For each triangular factor in
          turn, that factor's axis is moved last and ``X @ tri_q = conjB`` is solved for X (upper triangular,
          ``left=False``).
    """
    if conjB is None:
        conjB = torch.randn_like(G)
    exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim)  # calcA expr and cached precond expr are the same
    A = casted_einsum(exprA, *Q, G)
    # solve_triangular is kept out of torch.compile graphs
    solve = torch.compiler.disable(torch.linalg.solve_triangular)
    transposed_shape = original_shape = conjB.shape
    prev_i = -1  # axis currently swapped to the end; -1 means "no swap yet" (transpose(-1, -1) is a no-op)
    qs, conjB = _psgd_calc_scalars_(Q, conjB)
    for i, tri_q in qs:
        # bring axis i to the end so the triangular solve acts along it
        conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
        prev_i = i
        conjB = solve(tri_q, conjB, upper=True, left=False)
    # restore the original axis order
    conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
    return A, conjB
|
1659
1954
|
|
1660
1955
|
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1666
|
-
|
1667
|
-
|
1668
|
-
x
|
1669
|
-
|
1670
|
-
|
1671
|
-
|
1956
|
+
@decorator_knowngood
def _random_projection(x: Tensor, scale: Optional[Tensor]):
    """
    Sketch the 2-D matrix ``x`` by keeping its ``k`` columns with the largest L2 norm, divided by ``scale``.

    ``k`` is the smallest power of two >= ``log2(min(x.shape))``, and 1 when ``min(x.shape) < 2``. Despite the
    name, the column selection is deterministic (top-k by column norm). If ``scale`` is None, the largest
    absolute entry of ``x`` (clamped to >= 1e-8) is used.

    Returns ``(sketch, scale)``.
    """
    if scale is None:
        scale = x.norm(float("inf")).clamp(min=1e-8)
    min_size = min(x.shape)
    if min_size < 2:
        # log2(log2(1)) would be log2(0) -> math domain error; a single column is the only meaningful sketch
        k = 1
    else:
        k = 2 ** math.ceil(math.log2(math.log2(min_size)))  # next-largest-power-of-2 of log2-of-size
    norm = x.square().sum(0)
    indices = torch.topk(norm, k, largest=True).indices
    return x.index_select(1, indices).contiguous() / scale, scale
|
1964
|
+
|
1965
|
+
|
1966
|
+
def max_singular_value_exact(A, use_lobpcg: bool = False):
    """
    Largest singular value (spectral norm) of the 2-D matrix ``A``.

    By default, all singular values are computed with ``torch.linalg.svdvals`` in promoted precision. The
    ``gesvdj`` cuSOLVER driver is requested only for CUDA tensors, because PyTorch's ``driver=`` keyword is
    CUDA-only. Passing it for a CPU tensor raises, which previously sent every CPU call to the approximate
    fallback.

    With ``use_lobpcg=True``, the top eigenvalue of the Gram matrix ``A @ A.T`` is found with ``torch.lobpcg``,
    and its square root is returned.

    On ``LinAlgError``/``RuntimeError``, a 2-iteration power iteration on ``A`` itself is returned instead.
    """
    try:
        if use_lobpcg:
            # keep A intact so the fallback below still estimates sigma(A), not sigma(A A^T) = sigma(A)^2
            gram = A @ A.T
            eigval, _ = torch.compiler.disable(torch.lobpcg)(gram, k=1, largest=True)
            return eigval[0].sqrt()
        driver = "gesvdj" if A.is_cuda else None
        return torch.linalg.svdvals(promote(A), driver=driver).max().to(A.dtype)  # == linalg.matrix_norm(A, ord=2)
    except (torch.linalg.LinAlgError, RuntimeError):
        return max_singular_value_power_iter(promote(A), iterations=2)
|
1976
|
+
|
1977
|
+
|
1978
|
+
@decorator_knowngood
def max_singular_value_power_iter(A_outer: Tensor, max_abs: Optional[Tensor] = None, iterations: int = 5):
    """
    Rayleigh quotient of row with the largest norm + optional power iterations

    Estimates the largest singular value of the 2-D matrix ``A_outer``:
      * start vector ``x``: the row of ``A_outer`` with the largest L2 norm (``x_norm``);
      * ``A`` and ``x`` are both divided by ``x_norm`` to keep values in range;
      * ``iterations`` normalised power-iteration steps on ``A^T A``;
      * result: ``sqrt(x^T A^T A x) * x_norm``.
    If every row is zero (``x_norm == 0``), ``x_norm`` itself is returned.
    ``max_abs`` is currently unused.
    """
    x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
    x_norm = promote(x_norm)

    def _inner():
        A = A_outer
        x = A.index_select(0, max_idx).flatten().contiguous()
        A = stochastic_round_(A / x_norm)
        x = x / x_norm  # unit norm, since x_norm is exactly this row's norm

        def _mv(x):
            return promote(A.T.mv(A.mv(stochastic_round_(x))))

        for _ in range(iterations):
            # A.T @ (A @ x), but explicitly telling torch.compile not to compute the full matrix
            x = F.normalize(_mv(x), dim=0)
        out = (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
        return out.squeeze().clone()

    return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone())
|
2002
|
+
|
2003
|
+
|
2004
|
+
@decorator_knowngood
def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
    """
    Adapted from @evanatyourservice

    Sketch-based estimate of the largest singular value of the 2-D matrix ``A``:
      1. ``Y``: the top-norm columns of ``A`` divided by ``max_abs`` (``_random_projection``). ``max_abs``
         defaults to the largest absolute entry of ``A``.
      2. ``Q``: orthogonalised ``Y``, divided by ``max_abs``.
      3. ``Z = A.T @ Q``; ``W``: orthogonalised ``Z``.
      4. Result: the exact largest singular value of the small matrix ``Z.T @ W``, rescaled by ``max_abs``.
    NOTE(review): the name mentions Cholesky, but this body only calls ``inplace_orthogonal_`` with
    ``precise_zeroth_power_mode``; whether that is Cholesky-based is decided elsewhere — confirm.
    """
    Y, max_abs = _random_projection(A, max_abs)
    Q = inplace_orthogonal_(Y, precise_zeroth_power_mode)
    Q = Q / max_abs
    Z = A.T @ Q
    W = inplace_orthogonal_(Z, precise_zeroth_power_mode)
    sketch_norm = max_singular_value_exact(Z.T @ W)
    return sketch_norm * max_abs  # undo the 1 / max_abs scaling applied to Q
|
2016
|
+
|
2017
|
+
|
2018
|
+
def _max_singular_value_ndim(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
    """
    Largest singular value, generalised to tensors with more than two dimensions.

    For ``A.ndim <= 2``, this is ``max_singular_value``. Otherwise, for every axis ``b``:
      * contract all other axes to form the Gram matrix of the mode-b unfolding, ``A_(b) A_(b)^T``;
      * estimate its largest singular value.
    Returns the square root of the largest estimate, in ``A.dtype``.
    """
    if A.ndim <= 2:
        return max_singular_value(A, max_svd, use_cholesky, power_iter)

    letters = einsum_base[: A.ndim]
    A_rounded = stochastic_round_(A)
    grams = []
    for letter in letters:
        upper = letter.upper()
        gram_expr = f"{letters},{letters.replace(letter, upper)}->{letter}{upper}"
        grams.append(compiled_einsum(gram_expr, A_rounded, A_rounded))
    estimates = torch.stack([max_singular_value(promote(g), max_svd, use_cholesky, power_iter) for g in grams])
    # sqrt because the SVD was taken of a squared (Gram) matrix
    return estimates.max().sqrt().to(A.dtype)
|
2028
|
+
|
2029
|
+
|
2030
|
+
@decorator_knowngood
def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16) -> Tensor:
    """
    Largest singular value of a tensor with at most two dimensions, choosing the method by size.

    * ndim < 2: largest absolute entry.
    * ndim > 2: ValueError (see ``_max_singular_value_ndim``).
    * min(A.shape) <= max_svd: exact SVD (``max_singular_value_exact``).
    * use_cholesky, or power_iter < 0: sketch estimate (``max_singular_value_cholesky``).
    * otherwise: ``power_iter`` power iterations (``max_singular_value_power_iter``).
    """
    dims = A.ndim
    if dims < 2:
        return A.abs().max()
    if dims > 2:
        raise ValueError("max_singular_value: dimension of A must be less than or equal to 2")
    if min(A.shape) <= max_svd:
        # SVD needs ~25% more runtime for size=32, but 0% error instead of 5%
        return max_singular_value_exact(A)
    if use_cholesky or power_iter < 0:
        return max_singular_value_cholesky(A)
    return max_singular_value_power_iter(A, None, iterations=power_iter)
|
2041
|
+
|
2042
|
+
|
2043
|
+
@decorator_knowngood
def clamped_max_singular_value(
    A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
) -> Tensor:
    """
    ``max(max_singular_value(A), min)``, skipping the estimate when it cannot exceed ``min``.

    The Frobenius (L2) norm upper-bounds the spectral norm. If it is already <= ``min``, the cheap norm is used
    in place of the singular-value estimate; the result is then clamped to ``min`` either way.
    """
    frobenius = A.norm()
    estimate = cond(
        frobenius > min,
        lambda: max_singular_value(A, max_svd, use_cholesky, power_iter),
        lambda: frobenius.clone(),
    )
    return estimate.clamp(min=min)
|
2050
|
+
|
2051
|
+
|
2052
|
+
@decorator_knowngood
def min_singular_value(
    A: Tensor,
    power_iter: int = 5,
    safety: float = 1.05,
    max_svd: int = 32,
):
    """
    Estimate the smallest eigenvalue of the square matrix ``A`` (its smallest singular value when ``A`` is
    symmetric positive semi-definite).

    * ndim < 2: smallest absolute entry.
    * n <= max_svd: exact ``torch.linalg.eigvalsh``, which assumes ``A`` is symmetric/Hermitian. On
      ``LinAlgError``, falls through to the iterative estimate.
    * otherwise: shifted power iteration.
      - ``lambda_upper = safety * max_singular_value(A)`` bounds the spectrum from above.
      - The dominant eigenvalue ``mu`` of ``lambda_upper * I - A`` gives ``lambda_min = lambda_upper - mu``.
      - The start vector is the row of ``A`` with the smallest norm, or uniform random if that row is zero.
    If the iterative estimate is non-finite or <= 0, ``_approx`` is returned instead:
    ``mu - sqrt(sigma^2 / (n - 1))``, with ``mu = trace(A) / n`` and ``sigma^2 = ||A||_F^2 / n - mu^2``.
    """
    if A.ndim < 2:
        return A.abs().min()

    n = A.size(0)
    if n <= max_svd:
        try:
            eigs = torch.linalg.eigvalsh(promote(A))
            return eigs.min().to(A.dtype)
        except torch.linalg.LinAlgError:
            pass  # fall back to the iterative estimate below

    lambda_max_hat = max_singular_value(A, power_iter=power_iter)
    lambda_upper = lambda_max_hat * safety  # safety margin so lambda_upper * I - A stays PSD

    row_norms = A.norm(dim=1)
    norm, idx = row_norms.min(dim=0)
    v = cond(norm > 0, lambda: A.index_select(0, idx).flatten(), lambda: torch.rand_like(A[0]))

    v = v / promote(v.norm())
    for _ in range(power_iter):
        # power iteration on the shifted matrix (lambda_upper * I - A)
        v = lambda_upper * v - promote(A.mv(stochastic_round_(v)))
        v = v / promote(v.norm())
    mu_hat = v @ (lambda_upper * v - promote(A.mv(stochastic_round_(v))))  # Rayleigh quotient of the shifted matrix

    lambda_min_hat = lambda_upper - mu_hat

    def _approx():
        mu = A.trace() / n
        sigma_square = A.square().sum() / n - mu**2
        return mu - (sigma_square / (n - 1)).sqrt()

    return cond(
        (~torch.isfinite(lambda_min_hat)) | (lambda_min_hat <= 0), _approx, lambda: lambda_min_hat.clone()
    ).squeeze()
|
2093
|
+
|
2094
|
+
|
2095
|
+
@decorator_knowngood
def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
    """
    Balance the preconditioner factors in place and return them as dense tensors.

    ``Q`` is either a list of tensors, or, when stored as lines, a list of tuples whose second element is the
    tensor. In the tuple case, the balanced entries are expanded with ``line_to_triu(Q, symmetric_output)``;
    otherwise ``Q`` itself is returned.
    """
    stored_as_lines = isinstance(Q[0], tuple)
    tensors = [entry[1] for entry in Q] if stored_as_lines else Q
    psgd_balance_Q(tensors)
    if not stored_as_lines:
        return Q
    return line_to_triu(Q, symmetric_output)
|
2102
|
+
|
2103
|
+
|
2104
|
+
@functools.lru_cache(maxsize=None)
def calcG_expr(q_dim, g_dim):
    """
    Build one einsum per preconditioner factor that contracts a ``g_dim``-D tensor with itself over all axes
    except the factor's own.

    * 2-D factor at axis i: axis i of the second operand is relabelled ``Z``, giving an ``(n_i, n_i)`` result.
    * any other factor: axis i is shared, giving an ``(n_i,)`` result.
    ``q_dim`` is the tuple of factor ndims (see ``ndim_tuple``). Results are cached per ``(q_dim, g_dim)``.
    """
    base = einsum_base[:g_dim]
    exprs = []
    for axis, ndim in enumerate(q_dim):
        if ndim == 2:
            rhs = base[:axis] + "Z" + base[axis + 1 :]
            out = base[axis] + "Z"
        else:
            rhs = base
            out = base[axis]
        exprs.append(f"{base},{rhs}->{out}")
    return exprs
|
2117
|
+
|
2118
|
+
|
2119
|
+
def eye_like(x: Tensor):
    """
    Identity "shaped like" ``x``: ones for 0-D/1-D (diagonal) tensors, an identity matrix for square 2-D tensors.

    Non-square 2-D inputs and inputs with more than 2 dims fail the assertions (AssertionError).
    The result uses ``x``'s device and dtype.
    """
    if x.ndim >= 2:
        assert x.ndim == 2
        assert x.size(0) == x.size(1)
        return torch.eye(x.size(0), device=x.device, dtype=x.dtype)
    return torch.ones_like(x)
|
2125
|
+
|
2126
|
+
|
2127
|
+
@decorator_knowngood
def _gg_inverse_via_vjp(G: Tensor, Q: List[Tensor]):
    """
    Idea:
    The preconditioned gradient LGR should be a zeroth power (orthogonal), so all Qs together should
    approximate G's inverse "square root".
    Assuming G is 2-dimensional, there are two preconditioning Q's: L, R.
    Optimize LGR being a zeroth power using `MSE( (LGR) (LGR).T , I ) + MSE( (LGR).T (LGR) , I )`,
    then backprop to L/R jointly.
    This function computes the gradients for L/R; an outer optimizer layer handles the rest.

    `psgd_precond_grad` computes LGR for the general (n-dimensional) case.
    `calcG_expr` provides the einsum expressions to compute (LGR)(LGR).T (and (LGR).T(LGR)) for the general
    n-dim case.

    Args:
        G: Gradient that should be orthogonalized
        Q: List of preconditioner tensors.

    Returns:
        ``(grads, P)``:
        - List of gradients with respect to Q (d_Q); 2-D entries are symmetrised.
        - ``P = psgd_precond_grad(G, Q)`` (computed on stochastically rounded inputs), cast to ``G.dtype``.
    """
    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)

    G16 = stochastic_round_(G)
    Q16 = [stochastic_round_(q) for q in Q]
    P = psgd_precond_grad(G16, Q16)  # Q₀GQ₁

    # d_P: gradient of the per-axis "P Pᵀ - I" errors with respect to P
    d_P = torch.zeros_like(G)
    base = einsum_base[: G.ndim]
    for i, exprG in enumerate(exprGs):
        pp = compiled_einsum(exprG, P, P)
        error = pp - eye_like(pp)
        dim = einsum_base[i]
        if pp.ndim == 2:
            new = dim.upper()
            prec = f"{new}{dim}"
        else:
            new = dim
            prec = dim
        d_P += torch.einsum(f"{base},{prec}->{base.replace(dim, new)}", P, error)

    d_P = stochastic_round_(d_P)  # accumulate in fp32 and round at the end
    grads = []
    for i, exprG in enumerate(exprGs):
        # P with factor i replaced by identity, i.e. everything except Q_i applied to G
        new_q = Q16[:]
        new_q[i] = eye_like(new_q[i])
        pq = psgd_precond_grad(G16, new_q)
        grad = compiled_einsum(exprG, pq, d_P)
        if grad.ndim == 2:
            grad = (grad + grad.T) / 2
        grads.append(grad)

    return grads, P.to(G.dtype)
|
2178
|
+
|
2179
|
+
|
2180
|
+
def _inverse_initial_guess(gg):
    """
    Cheap spectral bounds for the square matrix ``gg``, computed in promoted precision.

    Returns ``(sigma_max, sigma_min_approx)``:
      * ``sigma_max``: the Frobenius norm of ``gg``, an upper bound on its largest singular value;
      * ``sigma_min_approx``: ``trace(gg) / (n * sigma_max)``, a rough estimate of the smallest one.
    """
    size = gg.shape[0]
    upper = promote(gg.norm())
    trace = promote(torch.trace(gg))
    lower_estimate = trace / (size * upper)
    return upper, lower_estimate
|
2189
|
+
|
2190
|
+
|
2191
|
+
@decorator_knowngood
def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
    """
    Chebyshev interpolation coefficients of ``f(x) = (x + 1 + eps) ** -0.5`` on [-1, 1].

    Uses the discrete Chebyshev transform at the ``degree`` Chebyshev nodes ``x_k = cos((2k + 1) pi / (2 degree))``:
      * ``c_0 = sum_k f(x_k) / degree``;
      * ``c_j = 2 / degree * sum_k f(x_k) cos(j theta_k)`` for ``j = 1 .. degree - 1``.
    Everything is computed in float64 on ``device``. Returns ``(c_0, c_1..c_{degree-1})`` as float32.
    """
    k = torch.arange(degree, dtype=torch.float64, device=device)
    rotation = (2 * k + 1) * math.pi / (2 * degree)  # node angles theta_k
    f = (rotation.cos() + 1 + eps) ** -0.5  # f evaluated at the nodes x_k = cos(theta_k)
    rotation = (rotation.view(-1, 1) * k[1:].view(1, -1)).cos()  # cos(j * theta_k), shape (degree, degree - 1)
    coeff0 = f.sum() / degree
    coeffs = f @ rotation * 2 / degree
    return coeff0.float(), coeffs.float()
|
2200
|
+
|
2201
|
+
|
2202
|
+
@decorator_knowngood
def _psgd_default_preconditioner_grad(
    terms: List[Tuple[Tensor, Tensor]],
    Q: List[Tensor],
) -> List[Tensor]:
    """
    Per-factor PSGD preconditioner gradients.

    For each factor ``q`` with its (promoted) pair ``(x, y)``:
      * diagonal factors (ndim < 2): ``q * (x - y)``;
      * triangular factors: ``triu(q @ (x - y))``.
    """

    def _single(q, pair):
        diff = promote(pair[0]) - promote(pair[1])
        if q.ndim < 2:
            return q * diff
        return (q @ diff).triu()

    return [_single(q, pair) for q, pair in zip(Q, terms)]
|
1672
2218
|
|
1673
2219
|
|
1674
2220
|
@decorator
def psgd_update_precond(
    G: Tensor,
    precond_lr: float,
    oq: "TriuOrLine",
    store_triu_as_line: bool,
    velocity: Optional[List[Tensor]],
    beta2: float,
    ortho_method: Optional[str],
    V: Tensor,
    running_lower_bound: List[Tensor],
    lower_bount_beta: float,
    power_iter: int,
) -> None:
    """Update Kronecker product preconditioner Q with pair (V, G).

    Steps:
      1. Balance the stored factors ``oq`` in place and obtain dense triangular copies (``_balance_to_triu``).
      2. Compute ``A`` (factors applied to G) and ``conjB`` (factor inverses applied to V) with
         ``psgd_calc_A_and_conjB``.
      3. For every factor, contract ``A`` and ``conjB`` with themselves over all other axes (``calcG_expr``).
      4. Turn those pairs into per-factor gradients (``_psgd_default_preconditioner_grad``).
      5. Apply them in place to ``oq`` with a lower-bound-normalised step (``_psgd_precond_update_``).

    NOTE(review): ``velocity`` and ``ortho_method`` are not used in this body, and ``beta2`` is only passed
    through ``scalar_guard``. Presumably these parameters are kept for interface compatibility — confirm.
    """
    Q = _balance_to_triu(oq)
    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
    precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)

    A, conjB = psgd_calc_A_and_conjB(G, Q, V)
    terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
    del A, conjB, V  # release the large intermediates before the update step
    updates = _psgd_default_preconditioner_grad(terms, Q)
    _psgd_precond_update_(
        updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
    )
    return None
|
2247
|
+
|
2248
|
+
|
2249
|
+
@decorator_knowngood
def bf16_matmul(x: Tensor, y: Tensor):
    """Compute ``x @ y`` in promoted precision and cast the product back to ``x.dtype``."""
    product = torch.matmul(promote(x), promote(y))
    return product.to(x.dtype)
|
2252
|
+
|
2253
|
+
|
2254
|
+
def if_iscompiling(fn):
    """
    Decorator: dispatch to the ``torch`` function with the same name while ``torch.compile`` is tracing.

    Outside compilation, or if ``torch`` has no attribute of that name, the decorated Python fallback runs.
    All positional and keyword arguments are forwarded. This matters because the decorated helpers take
    several arguments, e.g. ``cond(pred, true_fn, false_fn)`` and ``while_loop(cond, body, state)``. The
    previous single-argument wrapper raised TypeError for them.
    """
    base = getattr(torch, fn.__name__, None)

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        if base is not None and torch.compiler.is_compiling():
            return base(*args, **kwargs)
        return fn(*args, **kwargs)

    return _fn
|
2263
|
+
|
2264
|
+
|
2265
|
+
@if_iscompiling
def while_loop(cond, body, state):
    """
    ``torch.while_loop`` while compiling; otherwise an eager reference loop (slow, but useful for debugging).

    Eagerly, ``state = body(*state)`` is repeated as long as ``cond(*state).item()`` is truthy, and the final
    state is returned.
    """
    while True:
        if not cond(*state).item():
            return state
        state = body(*state)
|
2274
|
+
|
2275
|
+
|
2276
|
+
@if_iscompiling
def cond(cond, true_fn, false_fn):
    """
    ``torch.cond`` while compiling; otherwise an eager reference (slow, but useful for debugging).

    Eagerly, ``cond.item()`` picks which zero-argument callable, ``true_fn`` or ``false_fn``, is called; its
    result is returned.
    """
    chosen = true_fn if cond.item() else false_fn
    return chosen()
|
2286
|
+
|
2287
|
+
|
2288
|
+
def cond_n(cond_val: Tensor, *fns):
    """
    N-way branch built from nested ``cond`` calls: call ``fns[cond_val]`` and return its result.

    ``cond_val`` is expected to be an integer tensor. Any value that does not match an earlier index selects
    the last function. Every branch is now *called*: previously the final branch returned the function
    object instead of its result. That made the branches' outputs inconsistent, and ``torch.cond`` requires
    tensor outputs from each branch.
    """
    fns = list(fns)
    fn = fns.pop(0)
    if not fns:
        return fn()
    return cond(cond_val == 0, fn, lambda: cond_n(cond_val - 1, *fns))
|
2294
|
+
|
2295
|
+
|
2296
|
+
@decorator_knowngood
def _psgd_precond_update_(
    matmuled: List[Optional[Tensor]],
    Q: "TriuOrLine",
    running_lower_bound: List[Tensor],
    lower_bount_beta: Tensor,
    precond_lr: Tensor,
    store_triu_as_line: bool,
    power_iter: int,
):
    """
    Apply the per-factor preconditioner gradients ``matmuled`` to the stored factors ``Q`` in place.

    For each factor:
      * ``lb``: the infinity norm of the update (diagonal factors) or its estimated largest singular value
        (triangular factors).
      * The running bound ``lb_state`` becomes ``max(lb, beta * lb_state + (1 - beta) * lb)``, an EMA that never
        falls below the current ``lb``. It is written back with ``copy_stochastic_``.
      * The factor becomes ``q - precond_lr * update / lb``.
    When factors are stored as lines (tuples), the tensor is the tuple's second element. With
    ``store_triu_as_line``, triangular updates are converted with ``triu_to_line`` to match that layout.
    """
    for update, oq, lb_state in zip(matmuled, Q, running_lower_bound):
        if isinstance(oq, tuple):
            oq = oq[1]

        q = promote(oq)
        if update.ndim < 2:
            lb = update.norm(float("inf"))
        else:
            lb = max_singular_value(update, power_iter=power_iter)
            update = promote(update)
            if store_triu_as_line:
                update = triu_to_line([update])[0][1]

        lb = promote(lb)
        lb = lb.maximum(promote(lb_state) + (lb - promote(lb_state)) * (1 - lower_bount_beta))
        copy_stochastic_(lb_state, lb)
        copy_stochastic_(oq, q - update / lb * precond_lr)
|
2323
|
+
|
2324
|
+
|
2325
|
+
@decorator_knowngood
def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
    """Compute per-factor updates for the inverse-free ("quad") PSGD preconditioner.

    Notation, per factor:
        I     -- identity
        T     -- target, ``gg``. In ``inverse_free_psgd_update_precond`` this is an einsum of the
                 preconditioned gradient with itself.
        Q     -- current factor ``q``
        scale -- ``gg.size(0) / numel`` for matrices, ``max(1, gg.numel()) / numel`` for vectors
    ---
        U = T * scale - I
        F = I - U            # = 2I - T * scale
        O = Q - F @ Q @ F    # matrix factors return O + O^T

    Vector factors (``ndim < 2``) use the elementwise analogue ``q - f * q * f`` with
    ``f = 2 - T * scale``.

    Args:
        GG: per-factor targets.
        Q: per-factor current preconditioner values, aligned with ``GG``.
        numel: element count of the full gradient, used for the scale.

    Returns:
        List of update tensors, one per factor.
    """
    out = []
    for gg, q in zip(GG, Q):
        if gg.ndim < 2:
            scale = max(1, gg.numel()) / numel
            target = promote(gg)
            update = target * scale - 1
            out.append(q - (1 - update) * q * (1 - update))
        else:
            scale = gg.size(0) / numel
            f = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
            update = q - f @ q @ f
            # Make the matrix symmetric as O + O^T (a sum, not an average).
            out.append(update + update.T)
    return out
|
2350
|
+
|
2351
|
+
|
2352
|
+
@decorator
def inverse_free_psgd_update_precond(
    G: Tensor,
    precond_lr: float,
    oq: List[Tensor],
    store_triu_as_line: bool,
    velocity: Optional[List[Tensor]],
    beta2: float,
    ortho_method: Optional[str],
    V: None,
    running_lower_bound: List[Tensor],
    lower_bount_beta: float,
    power_iter: int,
) -> Tensor:
    """Inverse-free PSGD step: precondition ``G`` with the factors in ``oq`` and update ``oq`` in place.

    This variant uses no probe vector, velocity or orthogonalisation, so ``V``, ``velocity`` and
    ``ortho_method`` must all be None.

    Returns:
        The preconditioned gradient.
    """
    assert V is None
    assert ortho_method is None
    assert velocity is None
    del V, ortho_method, velocity

    Q = _balance_to_triu(oq, True)
    precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)

    preconditioned = psgd_precond_grad(G, Q)
    # One self-contraction of the preconditioned gradient per factor.
    contractions = [compiled_einsum(expr, preconditioned, preconditioned) for expr in exprGs]
    matmuled = _psgd_quad_preconditioner_grad(contractions, Q, preconditioned.numel())
    _psgd_precond_update_(
        matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
    )
    return preconditioned
|
2383
|
+
|
2384
|
+
|
2385
|
+
@decorator_knowngood
def _clip(x, norm, clip_at, eps=1e-8):
    """Rescale ``x`` in place by ``clip_at / max(norm, clip_at, eps)``.

    For ``clip_at >= eps`` this equals ``min(1, clip_at / max(norm, eps))``. The measured ``norm`` is
    therefore capped at ``clip_at``, and tensors already below it are left unchanged.
    """
    # Folding the min(1, .) into the clamp: a / max(n, a, eps) == min(1, a / max(n, eps)).
    factor = clip_at / norm.clamp(min=max(clip_at, eps))
    copy_stochastic_(x, promote(x) * factor)
|
1699
2392
|
|
1700
2393
|
|
1701
2394
|
@decorator_knowngood
def _compilable_l2_clip_(xs, clip_at, eps=1e-8):
    """Cap each tensor's L2 norm at ``clip_at``, rescaling in place via ``_clip``."""
    for tensor in xs:
        l2 = promote(tensor).norm()
        _clip(tensor, l2, clip_at, eps)
|
1709
2398
|
|
1710
2399
|
|
1711
2400
|
@decorator_knowngood
def _compilable_l2_normalize_(xs, eps):
    """Divide each tensor in ``xs``, in place, by ``max(||x||_2, eps)``."""
    for x in xs:
        x32 = promote(x)
        copy_stochastic_(x, x32 / x32.norm().clamp(min=eps))


def l2_normalization_(x, clip_at: float = 1e-8):
    """Rescale each tensor in ``x`` in place to unit L2 norm and return the guarded list.

    ``clip_at`` floors the divisor: a tensor whose norm is below ``clip_at`` is divided by ``clip_at``
    rather than by its own tiny norm.

    Note: this must not go through ``_compilable_l2_clip_``. That helper caps the norm *at*
    ``clip_at``, so with the 1e-8 default it would shrink every tensor to norm 1e-8 instead of
    normalizing it.
    """
    x = list_guard(x)
    _compilable_l2_normalize_(x, clip_at)
    return x
|
1714
2404
|
|
1715
2405
|
|
1716
2406
|
def l2_clip_(x, clip_at: float = 1.0):
    """Cap each tensor's L2 norm at ``clip_at`` (in place); return the ``list_guard``-ed input."""
    tensors = list_guard(x)
    _compilable_l2_clip_(tensors, clip_at)
    return tensors
|
1719
2410
|
|
1720
2411
|
|
1721
2412
|
@decorator_knowngood
def _compilable_rmsnorm_clip_(xs, clip_at, eps=1e-8):
    """Cap each tensor's root-mean-square value at ``clip_at``, rescaling in place via ``_clip``."""
    for tensor in xs:
        rms = promote(tensor).square().mean().sqrt()
        _clip(tensor, rms, clip_at, eps)
|
1728
2416
|
|
1729
2417
|
|
1730
2418
|
def rmsnorm_clip_(x, clip_at: float = 1.0):
    """Cap each tensor's RMS at ``clip_at`` (in place); return the ``list_guard``-ed input."""
    tensors = list_guard(x)
    _compilable_rmsnorm_clip_(tensors, clip_at)
    return tensors
|
2422
|
+
|
2423
|
+
|
2424
|
+
@decorator_knowngood
def _compilable_global_rmsnorm_clip_(x, clip_at, eps=1e-8):
    """Scale all tensors in ``x`` by one shared factor so their joint RMS is at most ``clip_at``.

    The RMS is taken over every element of every tensor, as if they were concatenated.
    """
    total_elements = sum(t.numel() for t in x)
    sum_of_squares = sum(promote(t).square().sum() for t in x)
    rms = (sum_of_squares / total_elements) ** 0.5
    factor = clip_at / rms.clamp(min=max(clip_at, eps))
    stochastic_multiply_(x, factor)
|
2433
|
+
|
2434
|
+
|
2435
|
+
def global_rmsnorm_clip(x, clip_at: float = 1.0):
    """Jointly cap the RMS of all tensors in ``x`` at ``clip_at`` (in place); return the guarded list."""
    tensors = list_guard(x)
    threshold = scalar_guard(clip_at, tensors[0])
    _compilable_global_rmsnorm_clip_(tensors, threshold)
    return tensors
|
2440
|
+
|
2441
|
+
|
2442
|
+
@decorator_knowngood
def _compilable_global_l2norm_clip_(x, clip_at, eps=1e-8):
    """Scale all tensors in ``x`` by one shared factor so their combined L2 norm is at most ``clip_at``."""
    sum_of_squares = sum(promote(t).square().sum() for t in x)
    global_norm = sum_of_squares**0.5
    factor = clip_at / global_norm.clamp(min=max(clip_at, eps))
    stochastic_multiply_(x, factor)
|
2450
|
+
|
2451
|
+
|
2452
|
+
def global_l2norm_clip(x, clip_at: float = 1.0):
    """Jointly cap the L2 norm of all tensors in ``x`` at ``clip_at`` (in place); return the guarded list."""
    tensors = list_guard(x)
    threshold = scalar_guard(clip_at, tensors[0])
    _compilable_global_l2norm_clip_(tensors, threshold)
    return tensors
|
1733
2457
|
|
1734
2458
|
|
1735
2459
|
@decorator_knowngood
def _compilable_rmsnorm_normalize_(xs, eps):
    """Divide each tensor in ``xs``, in place, by ``max(rms(x), eps)``."""
    for x in xs:
        x32 = promote(x)
        copy_stochastic_(x, x32 / x32.square().mean().sqrt().clamp(min=eps))


def rmsnorm_normalize_(x, clip_at: float = 1e-6):
    """Rescale each tensor in ``x`` in place to unit RMS and return the guarded list.

    ``clip_at`` floors the divisor: a tensor whose RMS is below ``clip_at`` is divided by ``clip_at``
    rather than by its own tiny RMS.

    Note: this must not go through ``_compilable_rmsnorm_clip_``. That helper caps the RMS *at*
    ``clip_at``, so with the 1e-6 default it would shrink every tensor to RMS 1e-6 instead of
    normalizing it.
    """
    x = list_guard(x)
    _compilable_rmsnorm_normalize_(x, clip_at)
    return x
|
1738
2463
|
|
1739
2464
|
|
1740
2465
|
@decorator_knowngood
|
@@ -1867,35 +2592,25 @@ def triu_to_line(Q_list: List[Tensor]):
|
|
1867
2592
|
if q.dim() < 2:
|
1868
2593
|
out.append((None, q))
|
1869
2594
|
else:
|
1870
|
-
out.append((q.shape, q[tuple(torch.triu_indices(*q.shape))]))
|
2595
|
+
out.append((tuple(q.shape), q[tuple(torch.triu_indices(*q.shape))]))
|
1871
2596
|
return out
|
1872
2597
|
|
1873
2598
|
|
1874
|
-
|
1875
|
-
|
1876
|
-
assert n * (n + 1) == 2 * numel
|
1877
|
-
return n, n
|
1878
|
-
|
1879
|
-
|
1880
|
-
@decorator_knowngood
def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
    """Rebuild dense factors from the ``(shape, flattened upper triangle)`` pairs made by ``triu_to_line``.

    Pairs whose ``shape`` is None are passed through as-is. With ``symmetric_output`` the lower triangle
    is filled by mirroring the stored values; otherwise it stays zero.
    """
    dense = []
    for shape, packed in Q_list:
        if shape is None:
            dense.append(packed)
            continue
        rows, cols = torch.triu_indices(*shape, device=packed.device)
        mat = torch.zeros(shape, device=packed.device, dtype=packed.dtype)
        mat[rows, cols] = packed
        if symmetric_output:
            mat[cols, rows] = packed
        dense.append(mat)
    return dense
|
1891
2612
|
|
1892
2613
|
|
1893
|
-
def update_triu_(q_state, materialised):
|
1894
|
-
for (shape0, q), (shape1, m) in zip(q_state, triu_to_line(materialised)):
|
1895
|
-
assert shape0 == shape1
|
1896
|
-
copy_stochastic_(q, m)
|
1897
|
-
|
1898
|
-
|
1899
2614
|
# NOTE(review): module-level mutable set. Its readers and writers are outside this chunk; presumably it
# records warnings that have already been emitted so they are shown only once -- confirm against callers.
_warned = set()
|
1900
2615
|
|
1901
2616
|
|
@@ -1918,52 +2633,118 @@ def psgd_should_update(
|
|
1918
2633
|
return int(group[name]) > int(cumulative_prob)
|
1919
2634
|
|
1920
2635
|
|
2636
|
+
@functools.lru_cache(maxsize=None)
def cached_precond_grad_expr(Q_dim, grad_dim):
    """Build (and memoise) the einsum spec that applies cached dense factors to a gradient.

    Factor ``i`` is written with two indices (upper = output axis, lower = gradient axis) when
    ``Q_dim[i] == 2``, and with one index otherwise. Gradient axes covered by a matrix factor appear
    upper-cased in the output; the rest keep their letter.
    """
    operands = []
    for letter, rank in zip(einsum_base, Q_dim):
        operands.append(f"{letter.upper()}{letter}" if rank == 2 else letter)
    lhs = ",".join(operands)

    grad_axes = "".join(letter for letter, _ in zip(einsum_base, range(grad_dim)))
    out_axes = "".join(letter.upper() if letter.upper() in lhs else letter for letter in grad_axes)
    return f"{lhs},{grad_axes}->{out_axes}"
|
2643
|
+
|
2644
|
+
|
1921
2645
|
@decorator_knowngood
def precond_grad_cached_(
    ea: Tensor,
    cached_q: List[Tensor],
    caution: bool = False,
    grad: Optional[Tensor] = None,
    cast: bool = True,
):
    """Apply cached dense preconditioner factors to ``ea`` with a single einsum.

    Args:
        ea: tensor to precondition.
        cached_q: cached factors, one per axis (see ``cached_precond_grad_expr``).
        caution: if True, ``ea`` is first passed through ``_compilable_cautioning(grad, ea)``.
        grad: gradient used by cautioning.
        cast: if True, cast the result back to ``ea``'s dtype; otherwise keep the compute dtype.
    """
    if caution:
        ea = _compilable_cautioning(grad, ea)

    work_dtype = min_dtype(list(cached_q) + [ea])
    operands = [q.to(work_dtype) for q in cached_q]
    operands.append(ea.to(work_dtype))

    spec = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
    result = compiled_einsum(spec, *operands)
    return result.to(ea.dtype) if cast else result
|
1934
2663
|
|
1935
2664
|
|
2665
|
+
# Preconditioner factors in one of two layouts: a list of plain tensors, or a list of ``(shape, tensor)``
# pairs as produced by ``triu_to_line``. In the pair layout, ``shape`` is None for factors with fewer than
# 2 dims (tensor stored as-is); otherwise the tensor holds the flattened upper triangle.
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
|
2666
|
+
|
2667
|
+
|
1936
2668
|
@decorator_knowngood
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
    """Precondition ``ea`` with cached factors and apply the result to ``param`` in one compiled call."""
    # Keep the compute dtype (cast=False); update_param_ receives the un-cast update.
    update = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
    # Cautioning, if requested, was already applied to `ea` above.
    update_param_(param, update, lr, decay, caution=False)
|
1940
2672
|
|
1941
2673
|
|
1942
|
-
def fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
    """Eager entry point: turn ``lr``/``decay`` into tensors, then run the compiled fused update."""
    lr_t, decay_t = scalar_guard(lr, decay, param[0])
    _compilable_fused_precond_grad_cached_(ea, param, lr_t, grad, decay_t, caution, cached_q)
|
2677
|
+
|
2678
|
+
|
2679
|
+
@functools.lru_cache(maxsize=None)
def precond_grad_expr(Q_dim, grad_dim):
    """Build (and memoise) the einsum spec that applies ``Q^T Q`` factors to a gradient.

    Every factor appears twice in the operand list. A matrix factor is written ``(aux, X)`` and
    ``(aux, x)``, so summing over ``aux`` yields ``Q^T Q`` mapping gradient axis ``x`` to output axis ``X``.
    A vector factor is written ``x, x`` (elementwise ``q * q``). ``aux`` letters come from
    ``einsum_base[13:]``.
    """
    pieces = []
    for low, aux, rank in zip(einsum_base, einsum_base[13:], Q_dim):
        if rank == 2:
            pieces.append(f"{aux}{low.upper()},{aux}{low}")
        else:
            pieces.append(f"{low},{low}")
    lhs = ",".join(pieces)

    grad_axes = "".join(low for low, _ in zip(einsum_base, range(grad_dim)))
    out_axes = "".join(low.upper() if low.upper() in lhs else low for low in grad_axes)
    return f"{lhs},{grad_axes}->{out_axes}"
|
1945
2688
|
|
1946
2689
|
|
1947
2690
|
@decorator_knowngood
def psgd_precond_grad(
    ea: Tensor,
    preconds: TriuOrLine,
    caution: bool = False,
    grad: Optional[Tensor] = None,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Apply the PSGD preconditioner (``Q^T Q`` per axis) to ``ea`` and return it in ``ea``'s dtype.

    Args:
        ea: tensor to precondition.
        preconds: factors, either dense or in ``triu_to_line`` layout (see ``store_triu_as_line``).
        caution: if True, ``ea`` is first passed through ``_compilable_cautioning(grad, ea)``.
        grad: gradient used by cautioning.
        store_triu_as_line: if True, ``preconds`` are densified with ``line_to_triu`` first.
        symmetric_output: forwarded to ``line_to_triu`` (mirror the upper triangle into the lower).
    """
    if caution:
        ea = _compilable_cautioning(grad, ea)
    if store_triu_as_line:
        preconds = line_to_triu(preconds, symmetric_output)

    work_dtype = min_dtype(list(preconds) + [ea])
    factors = [q.to(work_dtype) for q in preconds]
    spec = precond_grad_expr(ndim_tuple(factors), ea.ndim)

    # The spec consumes every factor twice in a row, followed by the tensor being preconditioned.
    operands = []
    for factor in factors:
        operands.extend((factor, factor))
    operands.append(ea.to(work_dtype))

    return compiled_einsum(spec, *operands).to(ea.dtype)
|
1956
2708
|
|
1957
2709
|
|
1958
2710
|
@decorator_knowngood
def _compilable_fused_psgd_precond_grad(
    ea: Tensor,
    param,
    lr,
    grad,
    decay,
    caution,
    preconds: TriuOrLine,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Precondition ``ea`` with ``preconds`` and apply the result to ``param`` in one compiled call."""
    update = psgd_precond_grad(
        ea,
        preconds,
        caution=caution,
        grad=grad,
        store_triu_as_line=store_triu_as_line,
        symmetric_output=symmetric_output,
    )
    # Cautioning, if requested, was already applied inside psgd_precond_grad.
    update_param_(param, update, lr, decay, caution=False, grad=grad)
|
1962
2731
|
|
1963
2732
|
|
1964
|
-
def fused_psgd_precond_grad(
    ea: Tensor,
    param,
    lr,
    grad,
    decay,
    caution,
    preconds: TriuOrLine,
    store_triu_as_line: bool = False,
    symmetric_output: bool = False,
):
    """Eager entry point: turn ``lr``/``decay`` into tensors, then run the compiled fused PSGD update."""
    lr_t, decay_t = scalar_guard(lr, decay, param[0])
    _compilable_fused_psgd_precond_grad(
        ea, param, lr_t, grad, decay_t, caution, preconds, store_triu_as_line, symmetric_output
    )
|
1967
2748
|
|
1968
2749
|
|
1969
2750
|
@decorator_knowngood
|
@@ -2015,7 +2796,15 @@ def caution(g, update):
|
|
2015
2796
|
return _compilable_cautioning(g, update)
|
2016
2797
|
|
2017
2798
|
|
2018
|
-
def
|
2799
|
+
def _inner_precond_update_prob_schedule(
|
2800
|
+
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2801
|
+
):
|
2802
|
+
return max(min_prob, max_prob * decay ** max(n - flat_start, 0))
|
2803
|
+
|
2804
|
+
|
2805
|
+
def precond_update_prob_schedule(
|
2806
|
+
max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
|
2807
|
+
):
|
2019
2808
|
"""Anneal preconditioner update probability during beginning of training.
|
2020
2809
|
|
2021
2810
|
PSGD benefits from more preconditioner updates at the beginning of training,
|
@@ -2026,11 +2815,9 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
|
|
2026
2815
|
`min_prob` by ~4000 steps. Default settings work very well for most models and
|
2027
2816
|
training regimes.
|
2028
2817
|
"""
|
2029
|
-
|
2030
|
-
|
2031
|
-
|
2032
|
-
|
2033
|
-
return _schedule
|
2818
|
+
return functools.partial(
|
2819
|
+
_inner_precond_update_prob_schedule, max_prob=max_prob, min_prob=min_prob, decay=decay, flat_start=flat_start
|
2820
|
+
)
|
2034
2821
|
|
2035
2822
|
|
2036
2823
|
def merge_group(group, *tensors):
|
@@ -2164,3 +2951,16 @@ def _compilable_caution_no_scale(g: Tensor, update: Tensor):
|
|
2164
2951
|
def disable_caution_scaling():
    """Rebind the module-global ``_compilable_cautioning`` to ``_compilable_caution_no_scale``.

    Because this rebinds a module global, it affects code that looks the name up at call time.
    References captured before the call are not changed.
    """
    global _compilable_cautioning
    _compilable_cautioning = _compilable_caution_no_scale
|
2954
|
+
|
2955
|
+
|
2956
|
+
@decorator_knowngood
def sam_step(parameters, ball_size, adaptive: bool = True):
    """Perturb parameters in place along their gradients and return snapshots of the original values.

    For each parameter ``p``:
      * a detached clone of ``p`` is recorded;
      * the direction is ``p.grad``, or ``p.grad * p**2`` when ``adaptive``;
      * it is added to ``p.data`` via ``stochastic_add_(p.data, direction, ball_size)``
        (presumably scaled by ``ball_size`` -- confirm against ``stochastic_add_``);
      * ``p.grad`` is zeroed.

    Parameters whose ``.grad`` is None (frozen or unused) are snapshotted but not perturbed. The
    returned list therefore stays index-aligned with ``parameters``.

    Returns:
        List of cloned pre-perturbation parameter values, one per entry of ``parameters``.
    """
    old_params = []
    for p in parameters:
        old_params.append(p.detach().clone())
        if p.grad is None:
            # No gradient means no direction to perturb along; leave the parameter unchanged.
            continue
        grad = promote(p.grad)
        if adaptive:
            grad = grad * promote(p).square()
        stochastic_add_(p.data, grad, ball_size)
        p.grad.zero_()
    return old_params
|