heavyball-0.25.1-py3-none-any.whl → heavyball-1.0.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- heavyball/__init__.py +193 -40
- heavyball/chainable.py +475 -0
- heavyball/utils.py +318 -187
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
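The file list shows the 1.0.0 restructuring: every per-optimizer module (foreach_adamw.py, foreach_soap.py, the PSGD variants, and so on) was removed, and a single chainable.py plus a much larger __init__.py and utils.py take their place. A minimal usage sketch, assuming the top-level optimizer classes (for example `heavyball.ForeachAdamW`, which existed in 0.25.1) are still exported from `heavyball/__init__.py` in 1.0.0 — the class name and constructor arguments below are illustrative and not confirmed by this diff:

```python
import torch
import heavyball  # assumes heavyball 1.0.0 is installed

model = torch.nn.Linear(16, 4)
# Hypothetical: this diff does not show the 1.0.0 export list, so treat the
# class name and keyword arguments as assumptions.
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```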
heavyball/utils.py
CHANGED
@@ -1,3 +1,11 @@
+"""
+
+
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
 import functools
 import gc
 import math
@@ -16,6 +24,7 @@ compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
+tiny_bf16 = torch.finfo(torch.bfloat16).tiny


 def decorator(func):
@@ -60,41 +69,34 @@ def warmup(lr: float, step: int, warmup_steps: int):

 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
-                               beta1: Tensor):
+                               beta1: Tensor, decay: float):
+    grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
     p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
     for p_, z_, g_ in zip(p32, z32, g32):
+        if decay != 0:
+            g_.add_(p_, alpha=decay)
         p_.lerp_(z_, ckp1)
-        p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1)
-        z_.add_(g_, alpha
+        p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
+        z_.add_(g_, alpha=lr)
     copy_stochastic_list_(p, p32)
     copy_stochastic_list_(z, z32)


-def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
-    weight = lr ** weight_lr_power * max(step, 1) ** r
-    weight_sum = weight_sum + weight
-
-    try:
-        ckp1 = weight / weight_sum
-    except ZeroDivisionError:
-        ckp1 = 0
-    return ckp1, weight_sum
-
-
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
-                   z: List[Tensor], grad:
-    weight = lr ** weight_lr_power * max(step, 1) ** r
+                   z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
+    weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight

     try:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
         ckp1 = 0
+        ckp1 = 0

     # These operations update y in-place,
     # without computing x explicitly.
-    lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
-    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
+    lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum

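Editor's note: the schedule-free update above interpolates the parameters toward the averaged iterate `z` with weight `ckp1 = weight / weight_sum`, where `weight = abs(lr) ** weight_lr_power * max(step, 1) ** r`. A minimal sketch of one such step on plain float32 tensors, without `promote`, `scalar_guard`, or stochastic copies; the sign conventions mirror the hunk above (the `abs(lr)` suggests the caller may pass a signed learning rate), and the helper name is hypothetical:

```python
import torch

def schedule_free_step(p, z, grad, lr, beta1, weight_sum, step,
                       weight_lr_power=2.0, r=0.0, decay=0.0):
    # Averaging weight and interpolation factor, as in schedule_free_ above.
    weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
    weight_sum = weight_sum + weight
    ckp1 = weight / weight_sum if weight_sum else 0.0

    if decay != 0:
        grad = grad + decay * p          # decoupled weight decay folded into the gradient
    p.lerp_(z, ckp1)                     # pull the fast weights toward the average
    p.add_(grad, alpha=lr - lr * (beta1 * (1 - ckp1)))
    z.add_(grad, alpha=lr)               # z integrates the raw gradient
    return weight_sum

p, z = torch.zeros(3), torch.zeros(3)
ws = schedule_free_step(p, z, torch.ones(3), lr=0.1, beta1=0.9, weight_sum=0.0, step=1)
```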
@@ -162,10 +164,13 @@ def beta_debias(beta, step):
 @decorator_knowngood
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
-
-
-
-
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta2)
+    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    denom = torch._foreach_sqrt(s32)
+    [d.clamp_(min=eps) for d in denom]
+    copy_stochastic_list_(state, s32)
+
     if out[0] is None:
         return denom

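Editor's note: `_compilable_exp_avg_sq_` maintains the usual second-moment estimate `s ← beta2 * s + (1 - beta2) * g²` and returns `max(sqrt(s), eps)` as the denominator. A per-tensor sketch of the same arithmetic without the `_foreach` batching or stochastic copies, for reference only:

```python
import torch

def exp_avg_sq_step(state: torch.Tensor, grad: torch.Tensor, beta2: float, eps: float) -> torch.Tensor:
    """One EMA-of-squared-gradients update; returns the clamped denominator."""
    state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    return state.sqrt().clamp_(min=eps)

s = torch.zeros(4)
denom = exp_avg_sq_step(s, torch.randn(4), beta2=0.999, eps=1e-8)
```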
@@ -179,10 +184,27 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)


-
-
-
-
+@decorator_knowngood
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
+                                     out: List[Optional[Tensor]]):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta2)
+    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    denom = torch._foreach_sqrt(s32)
+    [d.clamp_(min=eps) for d in denom]
+    out = torch._foreach_div(g32, denom)
+    copy_stochastic_list_(state, s32)
+    return stochastic_round_list_(grad, out)
+
+
+def scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad), list_guard(exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, grad[0]), scalar_guard(eps, grad[0])
+    return _compilable_scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps, grad)
+
+
+@decorator_knowngood
+def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
     p_norm = torch._foreach_norm(parameters)
     g_norm = torch._foreach_norm(gradients)
     torch._foreach_maximum_(p_norm, minimum)
@@ -190,7 +212,16 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
     torch._foreach_div_(p_norm, g_norm)
     torch._foreach_mul_(p_norm, clip_val)
     torch._foreach_minimum_(p_norm, 1)
-    torch.
+    return torch._foreach_mul(gradients, p_norm)
+
+
+def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
+                                minimum: float = 1e-3, eps: float = 1e-8):
+    if clip_val <= 0:
+        return gradients
+    parameters, gradients = list_guard(parameters), list_guard(gradients)
+    clip_val = scalar_guard(clip_val, parameters[0])
+    return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)


 def is_compiling():
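Editor's note: `adaptive_gradient_clipping_` rescales each gradient tensor so its norm is at most `clip_val` times the floored parameter norm, i.e. the scaling factor is `min(1, clip_val * max(||p||, minimum) / ||g||)`. A per-tensor sketch of the same rule; the `eps` argument is accepted but not used in the lines shown above, so it is omitted, and the function name here is illustrative:

```python
import torch

def agc(param: torch.Tensor, grad: torch.Tensor, clip_val: float, minimum: float = 1e-3) -> torch.Tensor:
    """Adaptive gradient clipping for a single parameter/gradient pair."""
    if clip_val <= 0:
        return grad
    p_norm = param.norm().clamp(min=minimum)
    g_norm = grad.norm()
    scale = (clip_val * p_norm / g_norm).clamp(max=1.0)
    return grad * scale

g = agc(torch.randn(10, 10), torch.randn(10, 10), clip_val=0.1)
```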
@@ -205,10 +236,7 @@ def set_(dst: Tensor, src: Tensor):
         return
     if src.shape != dst.shape:
         src = src.reshape_as(dst)
-
-        dst.set_(src)
-    else:
-        dst.copy_(src)
+    dst.copy_(src)


 def clean():
@@ -353,8 +381,6 @@ def get_orthogonal_matrix(mat):

         Q = torch.flip(Q, [1])

-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
         final.append(Q)

     return final
@@ -369,6 +395,27 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
         copy_stochastic_(x_, x32)


+def get_beta1(group):
+    beta = None
+    if 'beta' in group:
+        beta = group['beta']
+    if beta is None and 'betas' in group:
+        beta = group['betas'][0]
+    if beta is None:
+        raise ValueError("Beta not found in group.")
+    return beta
+
+
+def get_beta2(group):
+    beta = None
+    if 'beta2_scale' in group:
+        step = max(group.get("step", 1), 1)
+        return 1 - step ** -group['beta2_scale']
+    if 'betas' in group:
+        return group['betas'][1]
+    raise ValueError("Beta2 not found in group.")
+
+
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     x, y = list_guard(x), list_guard(y)
     a = scalar_guard(a, x[0])
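Editor's note: when a group carries `beta2_scale`, `get_beta2` computes a step-dependent decay, `beta2 = 1 - step ** (-beta2_scale)`, which approaches 1 as training progresses (the PaLM-style schedule used by heavyball's PaLM variants). A few illustrative values, assuming `beta2_scale = 0.8` (the value itself is an assumption, not shown in this diff):

```python
# beta2 = 1 - step ** -beta2_scale, with beta2_scale = 0.8 (assumed value)
for step in (1, 10, 100, 1000, 10000):
    print(step, 1 - step ** -0.8)
# 1      0.0
# 10     0.8415...
# 100    0.9749...
# 1000   0.9960...
# 10000  0.9994...
```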
@@ -435,35 +482,35 @@ def min_dtype(xs: List[Tensor]):
     return torch.float32


-def update_preconditioner(grad,
+def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
     """
     Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
     """
-    compute_ggt(grad,
-    if state['Q'] is None:
-        state['Q'] = get_orthogonal_matrix(state['GG'])
+    compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
     if update_precond:
-        get_orthogonal_matrix_QR(
+        get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)


-def init_preconditioner(grad, state, max_precond_dim=10000, precondition_1d=False):
+def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
     """
     Initializes the preconditioner matrices (L and R in the paper).
     """
-    state['Q'] = None # Will hold all the eigenbases of the preconditioner.
     state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
     if grad.dim() == 1:
-        if
+        if precondition_1d or grad.shape[0] > max_precond_dim:
+            state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
+        else:
             state['GG'].append([])
-            return
-        state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
-        return

-
-
-
-
-
+    else:
+        for sh in grad.shape:
+            if sh > max_precond_dim:
+                state['GG'].append([])
+            else:
+                state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+
+    compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
+    state['Q'] = get_orthogonal_matrix(state['GG'])


 @decorator
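Editor's note: `init_preconditioner` allocates one `sh × sh` statistic per dimension that is small enough (`sh <= max_precond_dim`), and `compute_ggt` then accumulates gradient outer products into them. A hedged sketch of the kind of per-dimension statistic this maintains (Shampoo/SOAP-style `G·Gᵀ` accumulation); `compute_ggt`'s body is not part of this diff, so the EMA form and the helper name below are assumptions:

```python
import torch

def accumulate_ggt(gg: list, grad: torch.Tensor, max_precond_dim: int, beta: float):
    """EMA of grad contracted with itself over every axis except `idx`, one matrix per axis."""
    for idx, sh in enumerate(grad.shape):
        if sh > max_precond_dim:       # overly large dimensions stay unpreconditioned
            continue
        axes = [i for i in range(grad.dim()) if i != idx]
        outer = torch.tensordot(grad, grad, dims=(axes, axes))  # shape (sh, sh)
        gg[idx].lerp_(outer, 1 - beta)

g = torch.randn(8, 16)
gg = [torch.zeros(8, 8), torch.zeros(16, 16)]
accumulate_ggt(gg, g, max_precond_dim=10000, beta=0.999)
```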
@@ -629,74 +676,63 @@ class StatefulOptimizer(torch.optim.Optimizer):
         return loss


-
-class ScheduleFree(StatefulOptimizer):
-    def eval(self):
-        for group in self.param_groups:
-            train_mode = group['train_mode']
-            beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-            if beta1 > 0 and train_mode:
-                for p in group['params']:
-                    state = self.state_(p)
-                    if 'z' in state:
-                        # Set p.data to x
-                        z = promote(state['z'])
-                        p32 = promote(p.data)
-                        p32.lerp_(end=z, weight=1 - 1 / beta1)
-                        copy_stochastic_(p.data, p32)
-                group['train_mode'] = False
-
-    def train(self):
-        for group in self.param_groups:
-            train_mode = group['train_mode']
-            beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-            if beta1 > 0 and not train_mode:
-                for p in group['params']:
-                    state = self.state_(p)
-                    if 'z' in state:
-                        z = promote(state['z'])
-                        p32 = promote(p.data)
-                        p32.lerp_(end=z, weight=1 - beta1)
-                        copy_stochastic_(p.data, p32)
-                group['train_mode'] = True
-
-    def _step(self):
-        raise NotImplementedError
-
-
 def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
     for t, s in zip(target, source):
         copy_stochastic_(t, s)


 @decorator_knowngood
-def
-
+def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
+                      step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

-    g32,
+    g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]

-
-    denom = exp_avg_sq_(exp_avg_sq32,
+    [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+    denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+    u32 = torch._foreach_div(exp_avg32, denom)

+    copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-    return
+    return stochastic_round_list_(exp_avg, u32)


-def
-
-
-
-
-    denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
-    return denom
+def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step,
+                                                                                                        exp_avg[0])
+    return _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)


+@decorator_knowngood
+def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                            beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
+                            caution: bool):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+
+    [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+    denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+    u32 = torch._foreach_div(exp_avg32, denom)
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+
+
+def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+    y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+

 @decorator_knowngood
-def
-
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: Tensor,
+                        beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

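Editor's note: `_compilable_adam_` is the standard Adam direction with bias correction folded into the betas: first moment via `lerp`, second moment via `exp_avg_sq_`, update = m / max(sqrt(v), eps). A single-tensor reference version without `_foreach`, `promote`, or stochastic rounding; the `beta_debias` form used here, `1 - (1 - beta) / (1 - beta ** step)`, is not shown in this hunk and should be treated as an assumption:

```python
import torch

def adam_direction(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor,
                   beta1: float, beta2: float, step: int, eps: float = 1e-8) -> torch.Tensor:
    # Assumed beta_debias: fold the 1/(1 - beta**t) bias correction into the decay rate.
    b1 = 1 - (1 - beta1) / (1 - beta1 ** step)
    b2 = 1 - (1 - beta2) / (1 - beta2 ** step)
    m.lerp_(g, 1 - b1)                                  # first moment
    v.mul_(b2).addcmul_(g, g, value=1 - b2)             # second moment
    return m / v.sqrt().clamp(min=eps)                  # update direction

m, v = torch.zeros(4), torch.zeros(4)
u = adam_direction(m, v, torch.randn(4), beta1=0.9, beta2=0.999, step=1)
```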
@@ -709,27 +745,109 @@ def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


-def
-
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: float, beta2: float,
+            step: int):
     exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
     beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+    return exp_avg


 @decorator_knowngood
-def
-
-
-
+def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+                              grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
+                              decay: Tensor, caution: bool):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)

-
-    result.add_(source.view(dtype=torch.int32))
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]

-
+    denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+    gp32 = torch._foreach_div(gp32, denom)
+    stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+    update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
+
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
+                  beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
+    y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
+
+
+@decorator_knowngood
+def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+    update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+
+    beta1 = beta_debias(beta1, step)
+    denom = torch._foreach_sqrt(exp_avg_sq32)
+    [denom.clamp_(min=eps) for denom in denom]
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    beta2 = beta_debias(beta2, step + 1)
+    torch._foreach_mul_(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
+    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+    _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+
+
+@decorator_knowngood
+def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+    update = [e.clone() for e in exp_avg]
+
+    beta1 = beta_debias(beta1, step)
+    denom = torch._foreach_sqrt(exp_avg_sq32)
+    [denom.clamp_(min=1e-8) for denom in denom]
+    torch._foreach_mul_(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+    beta2 = beta_debias(beta2, step + 1)
+    torch._foreach_mul_(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+    copy_stochastic_list_(exp_avg, exp_avg32)
+    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+    return update
+
+
+def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+    grad, exp_avg_sq, exp_avg = list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
+    beta1, beta2, step = scalar_guard(beta1, grad[0]), scalar_guard(beta2, grad[0]), scalar_guard(step, grad[0])
+    return _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+
+
+def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
+    return [stochastic_round_(r, s) for r, s in zip(ref, source)]
+
+
+@decorator_knowngood
+def stochastic_round_(ref: Tensor, source: Tensor):
+    if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+        return source
+    if ref.dtype != torch.bfloat16:
+        return source.to(ref.dtype)
+    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+    result.add_(source.view(dtype=torch.int32))
     result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
+    return result.view(dtype=torch.float32).bfloat16()
+

-
-
+@decorator_knowngood
+def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
+    target.copy_(stochastic_round_(target, source))


 def copy_stochastic_(target: Tensor, source: Tensor):
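Editor's note: `stochastic_round_` converts float32 to bfloat16 by adding uniform noise in `[0, 2^16)` to the raw int32 bit pattern and masking the low 16 bits (the `FFFF0000` comment above), so each value rounds up or down with probability proportional to its distance from the two bfloat16 neighbours and the result is unbiased in expectation. A standalone sketch of the same bit trick, outside the library's helper and decorator machinery:

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Unbiased float32 -> bfloat16 rounding via noise below the bf16 mantissa cut."""
    assert x.dtype == torch.float32
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(dtype=torch.int32) + noise   # perturb the 16 bits that bf16 discards
    bits &= -65536                             # keep sign/exponent/upper mantissa (0xFFFF0000)
    return bits.view(dtype=torch.float32).bfloat16()

x = torch.full((10000,), 1.0 + 2 ** -8)        # exactly halfway between two bf16 values
print(stochastic_round_to_bf16(x).float().mean())  # ≈ 1.0039, the true value in expectation
print(x.bfloat16().float().mean())                 # exactly 1.0: nearest-even rounds every element down
```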
@@ -902,7 +1020,7 @@ def psgd_lb(A, max_abs):


 @decorator
-def psgd_update_precond(Q, exprs, G, precond_lr,
+def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs

@@ -923,10 +1041,10 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
         norm = term2.norm(float('inf'))
         if q.dim() < 2:
             term1 *= q.to(term1.dtype)
-            term1 /= norm.clamp_(min=
+            term1 /= norm.clamp_(min=tiny_bf16)
         else:
             torch.triu(term1, out=term1)
-            term1 /= psgd_lb(term2, norm).clamp_(
+            term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
             torch.matmul(term1, q, out=term1)
         if store_triu_as_line:
             term1 = triu_to_line([term1])[0][1]
@@ -935,22 +1053,32 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):


 @decorator_knowngood
-def
-
-
-
-
-
-
-
+def _compilable_l2_clip_(x):
+    ref = x
+    x = list(map(promote, x))
+    norm = torch._foreach_norm(x)
+    torch._foreach_maximum_(norm, 1e-8)
+    out = torch._foreach_div(x, norm)
+    return stochastic_round_list_(ref, out)
+

+def l2_clip_(x):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x)

-
+
+@decorator_knowngood
+def _compilable_rmsnorm_clip_(x):
+    x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-
-
-    torch.
-
+    norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
+    torch._foreach_maximum_(norm, 1e-6)
+    return torch._foreach_div(x, norm)
+
+
+def rmsnorm_clip_(x):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x)


 def mu_law_compress(x, mu=127.0):
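Editor's note: `rmsnorm_clip_` divides each tensor by its RMS (the L2 norm divided by sqrt(numel), floored at 1e-6), while `l2_clip_` divides by the raw L2 norm floored at 1e-8. A one-tensor sketch of the RMS variant, with an illustrative function name:

```python
import torch

def rmsnorm_clip(x: torch.Tensor) -> torch.Tensor:
    rms = x.norm() / x.numel() ** 0.5     # sqrt(mean(x**2))
    return x / rms.clamp(min=1e-6)

print(rmsnorm_clip(torch.randn(1024)).pow(2).mean())  # ≈ 1.0 after normalisation
```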
@@ -990,18 +1118,24 @@ def identity(x):
     return x


-
-
-
-
-    torch.
-
-    torch.
-
+@decorator_knowngood
+def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+    g32 = list(map(promote, grad))
+    [g.mul_(1 / scale) for g in g32]
+    tanh = torch._foreach_tanh(g32)
+    torch._foreach_abs_(g32)
+    torch._foreach_log1p_(g32)
+    [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
+
+    torch._foreach_maximum_(g32, -2)
+    torch._foreach_minimum_(g32, 2)
+    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]

-
-
-
+
+def trust_region_clip_(grad, lerp=0.9, scale=1.5):
+    grad = list_guard(grad)
+    lerp, scale = scalar_guard(lerp, grad[0]), scalar_guard(scale, grad[0])
+    return _compilable_trust_region_clip_(grad, lerp, scale)


 @decorator
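Editor's note: `_compilable_trust_region_clip_` maps each gradient element `g` to `scale * ((1 - lerp) * sign(g) * log1p(|g| / scale) + lerp * tanh(g / scale))` and then clamps to `[-2, 2]`: roughly linear near zero, strongly compressive for large entries. A scalar sketch of the same mapping (note that `g.copysign_(t)` above takes its sign from `tanh(g / scale)`, which equals `sign(g)`):

```python
import math

def trust_region(g: float, lerp: float = 0.9, scale: float = 1.5) -> float:
    soft = math.copysign(math.log1p(abs(g) / scale), g)   # log-compressed magnitude
    t = math.tanh(g / scale)                               # saturating branch
    out = scale * ((1 - lerp) * soft + lerp * t)
    return max(-2.0, min(2.0, out))

for g in (0.1, 1.0, 10.0, 100.0):
    print(g, trust_region(g))
# ≈ 0.0995, 0.863, 1.656, 1.982 — large gradients are squashed toward the ±2 trust region
```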
@@ -1040,60 +1174,57 @@ def update_triu_(q_state, materialised):
         copy_stochastic_(q, m)


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
-        group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
-        if prob is None:
-            prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
-        if group['stochastic_schedule']:
-            return self.rng.random() < prob
-        cumulative_prob = group.get(name, 0)
-        group[name] = cumulative_prob + prob
-        return int(group[name]) > int(cumulative_prob)
-
-    def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
-        for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
-            psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
-
-        if self.should_update(group, self.balance_probability, "balance_prob"):
-            for g, q in zip(grad_list, original_q if original_q else q_list):
-                if g.dim() > 1:
-                    if store_triu_as_line:
-                        psgd_balance_Q([q_ for _, q_ in q])
-                    else:
-                        psgd_balance_Q(q)
-
-
-# TODO: Figure out why this sometimes crashes
-# @decorator_knowngood
-def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
-                                     clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
+def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
+                       name: str = 'cumulative_prob'):
+    group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+    if not isinstance(prob, float):
+        prob = prob(group[f'{name}_prob_step'])
+    if group['stochastic_schedule']:
+        return rng.random() < prob
+    cumulative_prob = state.get(name, 0)
+    group[name] = cumulative_prob + prob
+    return int(group[name]) > int(cumulative_prob)
+
+
+@decorator_knowngood
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
     new = torch.einsum(expr, *args)
-
-
+    if cast:
+        return new.to(ea.dtype)
+    return new
+
+
+@decorator_knowngood
+def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
+    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+
+
+def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+    lr = scalar_guard(lr, param[0])
+    _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
+
+
+@decorator_knowngood
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+    md = min_dtype(list(preconds) + [ea])
+    args = [q.to(md) for q in preconds]
+    args = args + args + [ea.to(md)]
+    new = torch.einsum(expr, *args)
+    return new.to(ea.dtype)
+

+def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+    precond = psgd_precond_grad(expr, grad, *preconds)
+    update_param_(param, precond, lr, decay, caution=caution, grad=grad)

-
-
-    lr = scalar_guard(lr, param)
-
+
+def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+    lr = scalar_guard(lr, param[0])
+    _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)


 @decorator_knowngood
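Editor's note: with `stochastic_schedule` off, `psgd_should_update` accumulates the (possibly fractional) probability into a running total and fires exactly when the integer part of that total increases, so over many steps the update frequency matches the requested probability without randomness. A small self-contained demo of that accumulator:

```python
def deterministic_schedule(probs):
    """Yield True whenever the integer part of the cumulative probability grows."""
    cumulative = 0.0
    for p in probs:
        previous, cumulative = cumulative, cumulative + p
        yield int(cumulative) > int(previous)

fires = list(deterministic_schedule([0.25] * 12))
print(sum(fires), fires)   # 3 updates in 12 steps, evenly spaced
```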
@@ -1122,7 +1253,7 @@ def caution(g, update):
     _compilable_cautioning_(g, update)


-def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=
+def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
     """Anneal preconditioner update probability during beginning of training.

     PSGD benefits from more preconditioner updates at the beginning of training,