heavyball 0.25.1 → 1.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +207 -40
- heavyball/chainable.py +532 -0
- heavyball/utils.py +409 -231
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/METADATA +6 -5
- heavyball-1.1.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,3 +1,11 @@
+"""
+
+
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
 import functools
 import gc
 import math
@@ -15,7 +23,8 @@ from torch.utils._pytree import tree_map
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
+zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
+tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 
 
 def decorator(func):
@@ -60,30 +69,22 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
-                               beta1: Tensor):
+                               beta1: Tensor, decay: float):
+    grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
     p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
     for p_, z_, g_ in zip(p32, z32, g32):
+        if decay != 0:
+            g_.add_(p_, alpha=decay)
         p_.lerp_(z_, ckp1)
-        p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1)
-        z_.add_(g_, alpha
+        p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
+        z_.add_(g_, alpha=lr)
     copy_stochastic_list_(p, p32)
     copy_stochastic_list_(z, z32)
 
 
-def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
-    weight = lr ** weight_lr_power * max(step, 1) ** r
-    weight_sum = weight_sum + weight
-
-    try:
-        ckp1 = weight / weight_sum
-    except ZeroDivisionError:
-        ckp1 = 0
-    return ckp1, weight_sum
-
-
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
-                   z: List[Tensor], grad:
-    weight = lr ** weight_lr_power * max(step, 1) ** r
+                   z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
+    weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
     try:
@@ -91,10 +92,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
     except ZeroDivisionError:
         ckp1 = 0
 
-
-
-
-    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
+    grad, parameters, z = list_guard(grad, parameters, z)
+    lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
+    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum
 
 
@@ -162,10 +162,13 @@ def beta_debias(beta, step):
 @decorator_knowngood
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
-
-
-
-
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta2)
+    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    denom = torch._foreach_sqrt(s32)
+    [d.clamp_(min=eps) for d in denom]
+    copy_stochastic_list_(state, s32)
+
     if out[0] is None:
         return denom
 
@@ -174,15 +177,32 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad, out = list_guard(state
-    beta2, eps = scalar_guard(beta2,
+    state, grad, out = list_guard(state, grad, out)
+    beta2, eps = scalar_guard(beta2, eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
 
-
-
-
-
+@decorator_knowngood
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta2)
+    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    denom = torch._foreach_sqrt(s32)
+    [d.clamp_(min=eps) for d in denom]
+    out = torch._foreach_div_(g32, denom)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, out)
+
+
+def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, eps, grad[0])
+    _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
     p_norm = torch._foreach_norm(parameters)
     g_norm = torch._foreach_norm(gradients)
     torch._foreach_maximum_(p_norm, minimum)
@@ -190,7 +210,16 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
     torch._foreach_div_(p_norm, g_norm)
     torch._foreach_mul_(p_norm, clip_val)
     torch._foreach_minimum_(p_norm, 1)
-    torch.
+    return torch._foreach_mul(gradients, p_norm)
+
+
+def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
+                                minimum: float = 1e-3, eps: float = 1e-8):
+    if clip_val <= 0:
+        return gradients
+    parameters, gradients = list_guard(parameters, gradients)
+    clip_val = scalar_guard(clip_val, parameters[0])
+    return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
 
 
 def is_compiling():
@@ -205,10 +234,7 @@ def set_(dst: Tensor, src: Tensor):
         return
     if src.shape != dst.shape:
         src = src.reshape_as(dst)
-
-        dst.set_(src)
-    else:
-        dst.copy_(src)
+    dst.copy_(src)
 
 
 def clean():
@@ -226,33 +252,29 @@ def set_torch():
 
 
 @decorator
-def zeropower_via_newtonschulz5(G,
+def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     """
-    Modified from "modded-nanogpt" under the MIT license:
-    Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
     zero even beyond the point where the iteration no longer converges all the way to one everywhere
     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}'
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.
-
-    X = X / (X.norm() + eps)  # ensure top singular value <= 1
+    X = G.bfloat16()
+    X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
-        A = X @ X.T
-        B = A @
-
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
     if G.size(0) > G.size(1):
         X = X.T
-    return X
+    return X.to(G.dtype)
 
 
 def ortho(x):
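
Note: the hunk above moves the Newton-Schulz orthogonalization onto a bfloat16 path and returns in the input dtype. A minimal, self-contained sketch of the same quintic iteration (illustrative only; it runs in float32 rather than the package's bfloat16 path, with the coefficients copied from the hunk):

    import torch

    def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
        # Quintic Newton-Schulz iteration: pushes singular values toward ~1 without QR/SVD.
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G / (G.norm() + eps)      # ensure the top singular value is <= 1
        if G.size(0) > G.size(1):
            X = X.T                   # iterate on the wide orientation
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * A @ A) @ X
        if G.size(0) > G.size(1):
            X = X.T
        return X

    Q = newton_schulz_orthogonalize(torch.randn(64, 32))
    print(torch.linalg.svdvals(Q))    # roughly in [0.5, 1.5] rather than exactly 1, as the docstring notes
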
@@ -264,6 +286,53 @@ def ortho(x):
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
 
+@decorator_knowngood
+def _compilable_heavyball_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+@decorator_knowngood
+def _compilable_nesterov_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, g32)
+
+
+def heavyball_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_heavyball_momentum_(state, grad, beta)
+    return grad
+
+
+def nesterov_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_momentum_(state, grad, beta)
+    return grad
+
+
+@decorator_knowngood
+def inplace_orthogonal_(x, mode, out):
+    if mode == 'qr':
+        y = torch.linalg.qr(x).Q
+    elif mode == 'svd':
+        u, s, v = torch.linalg.svd(x)
+        y = u @ v.T
+    elif mode == 'newtonschulz':
+        y = zeropower_via_newtonschulz5(x, 5)
+    else:
+        raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    set_(out, y)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
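
Note: `get_orthogonal_matrix_QR` (context above) refreshes the SOAP eigenbases with one power-iteration step plus a re-orthogonalization, and the new `inplace_orthogonal_` centralizes the QR/SVD/Newton-Schulz choice. A rough standalone illustration of one such refresh step (not the exact heavyball code; `GG` and `Q` mirror the names in the hunk):

    import torch

    def power_iteration_qr_step(GG: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
        # Push the current eigenbasis through the accumulated Gram matrix,
        # sort by estimated eigenvalue, then re-orthogonalize with QR.
        tmp = GG @ Q
        est_eig = torch.einsum('ij,ij->j', Q, tmp)       # Rayleigh-quotient estimates
        order = torch.argsort(est_eig, descending=True)
        return torch.linalg.qr(tmp[:, order]).Q

    GG = torch.randn(16, 16)
    GG = GG @ GG.T                                       # symmetric PSD accumulator, as in SOAP
    Q = torch.linalg.qr(torch.randn(16, 16)).Q
    Q = power_iteration_qr_step(GG, Q)
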
@@ -294,17 +363,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-
-            set_(q, torch.linalg.eigh(m)[1])
-        elif zeroth_power_mode.startswith('newtonschulz'):
-            iterations = zeroth_power_mode[len('newtonschulz'):]
-            if iterations == '':
-                iterations = 10
-            else:
-                iterations = int(iterations)
-            set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
-        else:
-            set_(q, ortho(tmp[:, sort_idx]))
+        inplace_orthogonal_(tmp[:, sort_idx], q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -353,8 +412,6 @@ def get_orthogonal_matrix(mat):
 
         Q = torch.flip(Q, [1])
 
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
         final.append(Q)
 
     return final
@@ -369,24 +426,57 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
         copy_stochastic_(x_, x32)
 
 
+def get_beta1(group):
+    beta = None
+    if 'beta' in group:
+        beta = group['beta']
+    if beta is None and 'betas' in group:
+        beta = group['betas'][0]
+    if beta is None:
+        raise ValueError("Beta not found in group.")
+    return beta
+
+
+def get_beta2(group):
+    if 'beta2_scale' in group:
+        step = max(group.get("step", 1), 1)
+        return 1 - step ** -group['beta2_scale']
+    if 'betas' in group:
+        return group['betas'][1]
+    raise ValueError("Beta2 not found in group.")
+
+
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
-def list_guard(
-
-
-
+def list_guard(*xs):
+    out = []
+    for x in xs:
+        if isinstance(x, (list, tuple)):
+            out.append(x)
+        else:
+            out.append([x])
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
-def scalar_guard(
-
-
-
-
-
+def scalar_guard(*args):
+    *xs, ref = args
+    out = []
+    for x in xs:
+        if isinstance(x, float):
+            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+        elif isinstance(x, int):
+            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        else:
+            out.append(x)
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
 @decorator_knowngood
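
Note: the new `list_guard` / `scalar_guard` helpers above normalize arguments before they reach the compiled kernels: bare tensors become one-element lists, and Python floats/ints become 0-dim tensors on the reference tensor's device. A small usage sketch (assumes heavyball 1.1.0 is installed so the helpers are importable from `heavyball.utils`):

    import torch
    from heavyball.utils import list_guard, scalar_guard

    p = torch.zeros(3)
    params = list_guard(p)                  # -> [p]; lists/tuples pass through unchanged
    lr, beta = scalar_guard(1e-3, 0.9, p)   # -> 0-dim float32 tensors on p's device
    print(params, lr.dtype, beta.item())
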
@@ -399,7 +489,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
 
 
 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
-    x, y = list_guard(x
+    x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
@@ -435,35 +525,35 @@ def min_dtype(xs: List[Tensor]):
     return torch.float32
 
 
-def update_preconditioner(grad,
+def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
     """
     Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
     """
-    compute_ggt(grad,
-    if state['Q'] is None:
-        state['Q'] = get_orthogonal_matrix(state['GG'])
+    compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
     if update_precond:
-        get_orthogonal_matrix_QR(
+        get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)
 
 
-def init_preconditioner(grad, state, max_precond_dim=10000, precondition_1d=False):
+def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
     """
     Initializes the preconditioner matrices (L and R in the paper).
     """
-    state['Q'] = None  # Will hold all the eigenbases of the preconditioner.
     state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
     if grad.dim() == 1:
-        if
+        if precondition_1d or grad.shape[0] > max_precond_dim:
+            state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
+        else:
             state['GG'].append([])
-            return
-        state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
-        return
 
-
-
-
-
-
+    else:
+        for sh in grad.shape:
+            if sh > max_precond_dim:
+                state['GG'].append([])
+            else:
+                state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+
+    compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
+    state['Q'] = get_orthogonal_matrix(state['GG'])
 
 
 @decorator
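
Note: the reworked `init_preconditioner` above allocates one Gram-matrix accumulator per tensor dimension and leaves oversized dimensions unpreconditioned. A rough illustration of the multi-dimensional shape logic only (the 1D branch and state handling are simplified away; this is not the package's exact code):

    import torch

    def init_gg_accumulators(grad: torch.Tensor, max_precond_dim: int = 10000):
        # One sh x sh accumulator per dimension; dimensions above the cap get a placeholder.
        return [torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype) if sh <= max_precond_dim else []
                for sh in grad.shape]

    accs = init_gg_accumulators(torch.randn(8, 3))
    print([a.shape for a in accs])          # [torch.Size([8, 8]), torch.Size([3, 3])]
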
@@ -629,74 +719,87 @@ class StatefulOptimizer(torch.optim.Optimizer):
         return loss
 
 
-
-class ScheduleFree(StatefulOptimizer):
-    def eval(self):
-        for group in self.param_groups:
-            train_mode = group['train_mode']
-            beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-            if beta1 > 0 and train_mode:
-                for p in group['params']:
-                    state = self.state_(p)
-                    if 'z' in state:
-                        # Set p.data to x
-                        z = promote(state['z'])
-                        p32 = promote(p.data)
-                        p32.lerp_(end=z, weight=1 - 1 / beta1)
-                        copy_stochastic_(p.data, p32)
-            group['train_mode'] = False
-
-    def train(self):
-        for group in self.param_groups:
-            train_mode = group['train_mode']
-            beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-            if beta1 > 0 and not train_mode:
-                for p in group['params']:
-                    state = self.state_(p)
-                    if 'z' in state:
-                        z = promote(state['z'])
-                        p32 = promote(p.data)
-                        p32.lerp_(end=z, weight=1 - beta1)
-                        copy_stochastic_(p.data, p32)
-            group['train_mode'] = True
-
-    def _step(self):
-        raise NotImplementedError
-
-
 def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
     for t, s in zip(target, source):
         copy_stochastic_(t, s)
 
 
 @decorator_knowngood
-def
-
+def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
+                      step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    g32,
+    g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
 
-
-    denom = exp_avg_sq_(exp_avg_sq32,
+    [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+    denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+    u32 = torch._foreach_div(exp_avg32, denom)
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, u32)
+
+
+def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
+
+@decorator_knowngood
+def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                            beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
+                            caution: bool):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+
+    [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+    denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+    u32 = torch._foreach_div(exp_avg32, denom)
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
+    _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+
+
+def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+    y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
+    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+
 
+@decorator_knowngood
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
+                        beta2: Tensor, step: Tensor):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
+
+    denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+    gp32 = torch._foreach_div(gp32, denom)
+    stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, exp_avg)
 
-def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
-             beta1: float, beta2: float, step: int):
-    exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
-        grad), list_guard(grad_projected)
-    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-    denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
-    return denom
 
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
-def
-
+def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+                              grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
+                              decay: Tensor, caution: bool):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -705,31 +808,89 @@ def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+    update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-def
-
-    exp_avg, exp_avg_sq,
-    beta1,
-
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
-def
-
-
-
+def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+    update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+
+    beta1 = beta_debias(beta1, step)
+    denom = torch._foreach_sqrt(exp_avg_sq32)
+    [denom.clamp_(min=eps) for denom in denom]
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    beta2 = beta_debias(beta2, step + 1)
+    torch._foreach_mul_(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+
+
+@decorator_knowngood
+def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+    update = [e.clone() for e in exp_avg]
+
+    beta1 = beta_debias(beta1, step)
+    denom = torch._foreach_sqrt(exp_avg_sq32)
+    [denom.clamp_(min=1e-8) for denom in denom]
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    beta2 = beta_debias(beta2, step + 1)
+    torch._foreach_mul_(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, update)
 
-    # add the random number to the lower 16 bit of the mantissa
-    result.add_(source.view(dtype=torch.int32))
 
-
+def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    return grad
+
+
+def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
+    return [stochastic_round_(r, s) for r, s in zip(ref, source)]
+
+
+@decorator_knowngood
+def stochastic_round_(ref: Tensor, source: Tensor):
+    if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+        return source
+    if ref.dtype != torch.bfloat16:
+        return source.to(ref.dtype)
+    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+    result.add_(source.view(dtype=torch.int32))
     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
+    return result.view(dtype=torch.float32).bfloat16()
 
-
-
+
+@decorator_knowngood
+def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
+    target.copy_(stochastic_round_(target, source))
 
 
 def copy_stochastic_(target: Tensor, source: Tensor):
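
Note: `stochastic_round_` above implements stochastic rounding to bfloat16 by adding 16 random bits to the float32 bit pattern and then zeroing the low mantissa bits, so a value rounds up with probability proportional to the discarded remainder. A self-contained demonstration of the same bit trick (assumes float32 input; mirrors the hunk's use of -65536 == 0xFFFF0000):

    import torch

    def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
        # Add random noise into the low 16 bits, then truncate them:
        # the value rounds up with probability equal to the discarded fraction.
        noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
        bits = x.view(dtype=torch.int32) + noise
        bits = bits.bitwise_and(-65536)
        return bits.view(dtype=torch.float32).bfloat16()

    x = torch.full((100000,), 1.0 + 2 ** -10)            # not representable in bfloat16
    print(stochastic_round_to_bf16(x).float().mean())    # ~1.00098 on average, vs. 1.0 for plain truncation
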
@@ -759,7 +920,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
                   caution: bool = False, grad: List[Tensor] = None):
-    param, update, grad = list_guard(param
+    param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
@@ -865,11 +1026,15 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q):
+    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+    eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+    eps *= G.norm() / G.numel()
+    G += V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.
+    conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
@@ -902,7 +1067,7 @@ def psgd_lb(A, max_abs):
 
 
 @decorator
-def psgd_update_precond(Q, exprs, G, precond_lr,
+def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
 
@@ -923,10 +1088,10 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         norm = term2.norm(float('inf'))
         if q.dim() < 2:
             term1 *= q.to(term1.dtype)
-            term1 /= norm.clamp_(min=
+            term1 /= norm.clamp_(min=tiny_bf16)
         else:
             torch.triu(term1, out=term1)
-            term1 /= psgd_lb(term2, norm).clamp_(
+            term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
             torch.matmul(term1, q, out=term1)
         if store_triu_as_line:
             term1 = triu_to_line([term1])[0][1]
@@ -935,22 +1100,32 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 
 
 @decorator_knowngood
-def
-
-
-
-
-
-
-
+def _compilable_l2_clip_(x):
+    ref = x
+    x = list(map(promote, x))
+    norm = torch._foreach_norm(x)
+    torch._foreach_maximum_(norm, 1e-8)
+    out = torch._foreach_div(x, norm)
+    return stochastic_round_list_(ref, out)
+
 
+def l2_clip_(x):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x)
 
-
+
+@decorator_knowngood
+def _compilable_rmsnorm_clip_(x):
+    x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-
-
-    torch.
-
+    norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
+    torch._foreach_maximum_(norm, 1e-6)
+    return torch._foreach_div(x, norm)
+
+
+def rmsnorm_clip_(x):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x)
 
 
 def mu_law_compress(x, mu=127.0):
@@ -990,18 +1165,24 @@ def identity(x):
     return x
 
 
-
-
-
-
-    torch.
-
-    torch.
-
+@decorator_knowngood
+def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+    g32 = list(map(promote, grad))
+    [g.mul_(1 / scale) for g in g32]
+    tanh = torch._foreach_tanh(g32)
+    torch._foreach_abs_(g32)
+    torch._foreach_log1p_(g32)
+    [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
 
-    torch._foreach_maximum_(
-    torch._foreach_minimum_(
-    return grad
+    torch._foreach_maximum_(g32, -2)
+    torch._foreach_minimum_(g32, 2)
+    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+
+
+def trust_region_clip_(grad, lerp=0.9, scale=1.5):
+    grad = list_guard(grad)
+    lerp, scale = scalar_guard(lerp, scale, grad[0])
+    return _compilable_trust_region_clip_(grad, lerp, scale)
 
 
 @decorator
@@ -1040,60 +1221,57 @@ def update_triu_(q_state, materialised):
         copy_stochastic_(q, m)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
-        group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
-        if prob is None:
-            prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
-        if group['stochastic_schedule']:
-            return self.rng.random() < prob
-        cumulative_prob = group.get(name, 0)
-        group[name] = cumulative_prob + prob
-        return int(group[name]) > int(cumulative_prob)
-
-    def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
-        for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
-            psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
-
-        if self.should_update(group, self.balance_probability, "balance_prob"):
-            for g, q in zip(grad_list, original_q if original_q else q_list):
-                if g.dim() > 1:
-                    if store_triu_as_line:
-                        psgd_balance_Q([q_ for _, q_ in q])
-                    else:
-                        psgd_balance_Q(q)
-
-
-# TODO: Figure out why this sometimes crashes
-# @decorator_knowngood
-def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
-                                     clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
+def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
+                       name: str = 'cumulative_prob'):
+    group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+    if not isinstance(prob, float):
+        prob = prob(group[f'{name}_prob_step'])
+    if group['stochastic_schedule']:
+        return rng.random() < prob
+    cumulative_prob = state.get(name, 0)
+    group[name] = cumulative_prob + prob
+    return int(group[name]) > int(cumulative_prob)
+
+
+@decorator_knowngood
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
     new = torch.einsum(expr, *args)
-
-
+    if cast:
+        return new.to(ea.dtype)
+    return new
+
 
+@decorator_knowngood
+def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
+    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
 
-
-
-    lr = scalar_guard(lr, param)
-
+
+def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+    lr = scalar_guard(lr, param[0])
+    _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
+
+
+@decorator_knowngood
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+    md = min_dtype(list(preconds) + [ea])
+    args = [q.to(md) for q in preconds]
+    args = args + args + [ea.to(md)]
+    new = torch.einsum(expr, *args)
+    return new.to(ea.dtype)
+
+
+def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+    precond = psgd_precond_grad(expr, grad, *preconds)
+    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+
+
+def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+    lr = scalar_guard(lr, param[0])
+    _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)
 
 
 @decorator_knowngood
@@ -1122,7 +1300,7 @@ def caution(g, update):
     _compilable_cautioning_(g, update)
 
 
-def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=
+def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
     """Anneal preconditioner update probability during beginning of training.
 
     PSGD benefits from more preconditioner updates at the beginning of training,