heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between the two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
heavyball/utils.py CHANGED
@@ -1,3 +1,11 @@
+ """
+
+
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
  import functools
  import gc
  import math
@@ -16,6 +24,7 @@ compile_mode = "max-autotune-no-cudagraphs"
  dynamic = False
  compile_mode_recommended_to_none = None
  zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
+ tiny_bf16 = torch.finfo(torch.bfloat16).tiny


  def decorator(func):
@@ -60,41 +69,34 @@ def warmup(lr: float, step: int, warmup_steps: int):

  @decorator_knowngood
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
-                                beta1: Tensor):
+                                beta1: Tensor, decay: float):
+     grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
      p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
      for p_, z_, g_ in zip(p32, z32, g32):
+         if decay != 0:
+             g_.add_(p_, alpha=decay)
          p_.lerp_(z_, ckp1)
-         p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
-         z_.add_(g_, alpha=-lr)
+         p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
+         z_.add_(g_, alpha=lr)
      copy_stochastic_list_(p, p32)
      copy_stochastic_list_(z, z32)


- def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
-     weight = lr ** weight_lr_power * max(step, 1) ** r
-     weight_sum = weight_sum + weight
-
-     try:
-         ckp1 = weight / weight_sum
-     except ZeroDivisionError:
-         ckp1 = 0
-     return ckp1, weight_sum
-
-
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
-                    z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
-     weight = lr ** weight_lr_power * max(step, 1) ** r
+                    z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
+     weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
      weight_sum = weight_sum + weight

      try:
          ckp1 = weight / weight_sum
      except ZeroDivisionError:
          ckp1 = 0
+         ckp1 = 0

      # These operations update y in-place,
      # without computing x explicitly.
-     lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
-     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
+     lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
      return weight_sum

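
Note on the schedule-free changes above: schedule_free_ now weights by abs(lr) and threads an optional decay through to the compiled kernel, which adds decay * p to the gradient before the interpolation. The sketch below is illustrative only (not part of the package) and mirrors the weight/weight_sum bookkeeping; with r=0 and a constant learning rate, ckp1 reduces to a running average, 1/step.

    # Hypothetical helper reproducing the ckp1 computation in schedule_free_ above.
    def ckp1_schedule(lr=1e-3, weight_lr_power=2.0, r=0.0, steps=5):
        weight_sum = 0.0
        factors = []
        for step in range(1, steps + 1):
            weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
            weight_sum += weight
            factors.append(weight / weight_sum if weight_sum else 0.0)
        return factors

    print(ckp1_schedule())  # [1.0, 0.5, 0.333..., 0.25, 0.2]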
 
@@ -162,10 +164,13 @@ def beta_debias(beta, step):
  @decorator_knowngood
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                              out: List[Optional[Tensor]]):
-     torch._foreach_mul_(state, beta2)
-     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
-     denom = torch._foreach_sqrt(state)
-     [denom.clamp_(min=eps) for denom in denom]
+     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+     torch._foreach_mul_(s32, beta2)
+     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+     denom = torch._foreach_sqrt(s32)
+     [d.clamp_(min=eps) for d in denom]
+     copy_stochastic_list_(state, s32)
+
      if out[0] is None:
          return denom

@@ -179,10 +184,27 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
      return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)


- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
-                                 minimum: float = 1e-3, eps: float = 1e-8):
-     if clip_val <= 0:
-         return
+ @decorator_knowngood
+ def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
+                                      out: List[Optional[Tensor]]):
+     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+     torch._foreach_mul_(s32, beta2)
+     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+     denom = torch._foreach_sqrt(s32)
+     [d.clamp_(min=eps) for d in denom]
+     out = torch._foreach_div(g32, denom)
+     copy_stochastic_list_(state, s32)
+     return stochastic_round_list_(grad, out)
+
+
+ def scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps):
+     grad, exp_avg_sq = list_guard(grad), list_guard(exp_avg_sq)
+     beta2, eps = scalar_guard(beta2, grad[0]), scalar_guard(eps, grad[0])
+     return _compilable_scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps, grad)
+
+
+ @decorator_knowngood
+ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
      p_norm = torch._foreach_norm(parameters)
      g_norm = torch._foreach_norm(gradients)
      torch._foreach_maximum_(p_norm, minimum)
@@ -190,7 +212,16 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
      torch._foreach_div_(p_norm, g_norm)
      torch._foreach_mul_(p_norm, clip_val)
      torch._foreach_minimum_(p_norm, 1)
-     torch._foreach_mul_(gradients, p_norm)
+     return torch._foreach_mul(gradients, p_norm)
+
+
+ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
+                                 minimum: float = 1e-3, eps: float = 1e-8):
+     if clip_val <= 0:
+         return gradients
+     parameters, gradients = list_guard(parameters), list_guard(gradients)
+     clip_val = scalar_guard(clip_val, parameters[0])
+     return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)


  def is_compiling():
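
The clipping refactor above splits adaptive gradient clipping into a compiled kernel (_compilable_agc_) and a thin wrapper that now returns the clipped gradients instead of mutating them in place. A rough per-tensor restatement of the rule, assuming the usual AGC formulation (the packaged code operates on lists of tensors via torch._foreach_*):

    import torch

    def agc_single(param: torch.Tensor, grad: torch.Tensor, clip_val: float = 0.01,
                   minimum: float = 1e-3, eps: float = 1e-8) -> torch.Tensor:
        # scale = min(1, clip_val * max(||p||, minimum) / max(||g||, eps))
        p_norm = param.norm().clamp(min=minimum)
        g_norm = grad.norm().clamp(min=eps)
        scale = (clip_val * p_norm / g_norm).clamp(max=1.0)
        return grad * scale  # returned rather than mutated, matching the new API

    p, g = torch.randn(128, 64), torch.randn(128, 64) * 100
    clipped = agc_single(p, g)  # unusually large gradients are scaled down to ~clip_val * ||p||
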
@@ -205,10 +236,7 @@ def set_(dst: Tensor, src: Tensor):
          return
      if src.shape != dst.shape:
          src = src.reshape_as(dst)
-     if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
-         dst.set_(src)
-     else:
-         dst.copy_(src)
+     dst.copy_(src)


  def clean():
@@ -353,8 +381,6 @@ def get_orthogonal_matrix(mat):

          Q = torch.flip(Q, [1])

-         if not float_data:
-             Q = Q.to(original_device).type(original_type)
          final.append(Q)

      return final
@@ -369,6 +395,27 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
          copy_stochastic_(x_, x32)


+ def get_beta1(group):
+     beta = None
+     if 'beta' in group:
+         beta = group['beta']
+     if beta is None and 'betas' in group:
+         beta = group['betas'][0]
+     if beta is None:
+         raise ValueError("Beta not found in group.")
+     return beta
+
+
+ def get_beta2(group):
+     beta = None
+     if 'beta2_scale' in group:
+         step = max(group.get("step", 1), 1)
+         return 1 - step ** -group['beta2_scale']
+     if 'betas' in group:
+         return group['betas'][1]
+     raise ValueError("Beta2 not found in group.")
+
+
  def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
      x, y = list_guard(x), list_guard(y)
      a = scalar_guard(a, x[0])
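
get_beta2 above introduces a step-dependent second-moment decay: when a group defines beta2_scale, beta2 is computed as 1 - step ** -beta2_scale, so the averaging window lengthens as training progresses. A small illustration (the 0.8 value here is only an example, not a documented default):

    def beta2_from_scale(step: int, beta2_scale: float = 0.8) -> float:
        step = max(step, 1)
        return 1 - step ** -beta2_scale

    for s in (1, 10, 100, 1000, 10000):
        print(s, round(beta2_from_scale(s), 4))
    # 1 -> 0.0, 10 -> ~0.8415, 100 -> ~0.9749, 1000 -> ~0.996, 10000 -> ~0.9994
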
@@ -435,35 +482,35 @@ def min_dtype(xs: List[Tensor]):
      return torch.float32


- def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
+ def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
      """
      Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
      """
-     compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
-     if state['Q'] is None:
-         state['Q'] = get_orthogonal_matrix(state['GG'])
+     compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
      if update_precond:
-         get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])
+         get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)


- def init_preconditioner(grad, state, max_precond_dim=10000, precondition_1d=False):
+ def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
      """
      Initializes the preconditioner matrices (L and R in the paper).
      """
-     state['Q'] = None # Will hold all the eigenbases of the preconditioner.
      state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
      if grad.dim() == 1:
-         if not precondition_1d or grad.shape[0] > max_precond_dim:
+         if precondition_1d or grad.shape[0] > max_precond_dim:
+             state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
+         else:
              state['GG'].append([])
-             return
-         state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
-         return

-     for sh in grad.shape:
-         if sh > max_precond_dim:
-             state['GG'].append([])
-         else:
-             state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+     else:
+         for sh in grad.shape:
+             if sh > max_precond_dim:
+                 state['GG'].append([])
+             else:
+                 state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+
+     compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
+     state['Q'] = get_orthogonal_matrix(state['GG'])


  @decorator
@@ -629,74 +676,63 @@ class StatefulOptimizer(torch.optim.Optimizer):
          return loss


-
- class ScheduleFree(StatefulOptimizer):
-     def eval(self):
-         for group in self.param_groups:
-             train_mode = group['train_mode']
-             beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-             if beta1 > 0 and train_mode:
-                 for p in group['params']:
-                     state = self.state_(p)
-                     if 'z' in state:
-                         # Set p.data to x
-                         z = promote(state['z'])
-                         p32 = promote(p.data)
-                         p32.lerp_(end=z, weight=1 - 1 / beta1)
-                         copy_stochastic_(p.data, p32)
-             group['train_mode'] = False
-
-     def train(self):
-         for group in self.param_groups:
-             train_mode = group['train_mode']
-             beta1 = group['beta'] if 'beta' in group else group['betas'][0]
-             if beta1 > 0 and not train_mode:
-                 for p in group['params']:
-                     state = self.state_(p)
-                     if 'z' in state:
-                         z = promote(state['z'])
-                         p32 = promote(p.data)
-                         p32.lerp_(end=z, weight=1 - beta1)
-                         copy_stochastic_(p.data, p32)
-             group['train_mode'] = True
-
-     def _step(self):
-         raise NotImplementedError
-
-
  def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
      for t, s in zip(target, source):
          copy_stochastic_(t, s)


  @decorator_knowngood
- def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
-                          grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
+                       step: Tensor):
      beta1 = beta_debias(beta1, step)
      beta2 = beta_debias(beta2, step)

-     g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+     g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]

-     stochastic_lerp_(exp_avg, g32, 1 - beta1)
-     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+     [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+     denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+     u32 = torch._foreach_div(exp_avg32, denom)

+     copy_stochastic_list_(exp_avg, exp_avg32)
      copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-     return denom
+     return stochastic_round_list_(exp_avg, u32)


- def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
-              beta1: float, beta2: float, step: int):
-     exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
-         grad), list_guard(grad_projected)
-     beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-     denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
-     return denom
+ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+     beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step,
+                                                                                                        exp_avg[0])
+     return _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)


+ @decorator_knowngood
+ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                             beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
+                             caution: bool):
+     beta1 = beta_debias(beta1, step)
+     beta2 = beta_debias(beta2, step)
+
+     g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+
+     [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+     denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+     u32 = torch._foreach_div(exp_avg32, denom)
+
+     copy_stochastic_list_(exp_avg, exp_avg32)
+     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+     _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+
+
+ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+     y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
+     beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+

  @decorator_knowngood
- def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                                 grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: Tensor,
+                         beta2: Tensor, step: Tensor):
      beta1 = beta_debias(beta1, step)
      beta2 = beta_debias(beta2, step)
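
The renamed adam_ / _compilable_adam_ pair above keeps the first moment in the update path (returning exp_avg / denom, stochastically rounded back to the storage dtype), and fused_adam_ additionally applies the update to the parameters with optional cautioning. A standalone per-tensor sketch of the underlying math, assuming beta_debias uses the common 1 - (1 - beta) / (1 - beta**step) correction:

    import torch

    def adam_update(exp_avg, exp_avg_sq, grad, beta1=0.9, beta2=0.999, step=1, eps=1e-8):
        b1 = 1 - (1 - beta1) / (1 - beta1 ** step)  # debiased beta1 (assumed form)
        b2 = 1 - (1 - beta2) / (1 - beta2 ** step)  # debiased beta2 (assumed form)
        exp_avg.lerp_(grad, 1 - b1)
        exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)
        return exp_avg / exp_avg_sq.sqrt().clamp(min=eps)

    g = torch.randn(10)
    m, v = torch.zeros(10), torch.zeros(10)
    u = adam_update(m, v, g)  # at step 1 this reduces to sign(g) elementwise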
 
@@ -709,27 +745,109 @@ def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
      copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


- def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
-                     beta1: float, beta2: float, step: int):
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: float, beta2: float,
+             step: int):
      exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
      beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-     _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+     _compilable_laprop_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+     return exp_avg


  @decorator_knowngood
- def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
-     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
-     # create a random 16 bit integer
-     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+                               grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
+                               decay: Tensor, caution: bool):
+     beta1 = beta_debias(beta1, step)
+     beta2 = beta_debias(beta2, step)

-     # add the random number to the lower 16 bit of the mantissa
-     result.add_(source.view(dtype=torch.int32))
+     gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]

-     # mask off the lower 16 bit of the mantissa
+     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+     gp32 = torch._foreach_div(gp32, denom)
+     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+     update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
+
+     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
+                   beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
+     y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
+     beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
+
+
+ @decorator_knowngood
+ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+     update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+
+     beta1 = beta_debias(beta1, step)
+     denom = torch._foreach_sqrt(exp_avg_sq32)
+     [denom.clamp_(min=eps) for denom in denom]
+     torch._foreach_mul_(exp_avg32, beta1)
+     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+     beta2 = beta_debias(beta2, step + 1)
+     torch._foreach_mul_(exp_avg_sq32, beta2)
+     [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+     copy_stochastic_list_(exp_avg, exp_avg32)
+     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+ def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+     y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
+     beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+
+
+ @decorator_knowngood
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+     g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+     update = [e.clone() for e in exp_avg]
+
+     beta1 = beta_debias(beta1, step)
+     denom = torch._foreach_sqrt(exp_avg_sq32)
+     [denom.clamp_(min=1e-8) for denom in denom]
+     torch._foreach_mul_(exp_avg32, beta1)
+     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+     beta2 = beta_debias(beta2, step + 1)
+     torch._foreach_mul_(exp_avg_sq32, beta2)
+     [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+     copy_stochastic_list_(exp_avg, exp_avg32)
+     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+     return update
+
+
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+     grad, exp_avg_sq, exp_avg = list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
+     beta1, beta2, step = scalar_guard(beta1, grad[0]), scalar_guard(beta2, grad[0]), scalar_guard(step, grad[0])
+     return _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+
+
+ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
+     return [stochastic_round_(r, s) for r, s in zip(ref, source)]
+
+
+ @decorator_knowngood
+ def stochastic_round_(ref: Tensor, source: Tensor):
+     if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+         return source
+     if ref.dtype != torch.bfloat16:
+         return source.to(ref.dtype)
+     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+     result.add_(source.view(dtype=torch.int32))
      result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
+     return result.view(dtype=torch.float32).bfloat16()
+

- # copy the higher 16 bit into the target tensor
- target.copy_(result.view(dtype=torch.float32))
+ @decorator_knowngood
+ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
+     target.copy_(stochastic_round_(target, source))


  def copy_stochastic_(target: Tensor, source: Tensor):
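
copy_stochastic_ is now built on the new stochastic_round_ above, which adds 16 random bits to an fp32 value before truncating to bfloat16, making the rounding unbiased in expectation. A quick self-contained check of that idea, using the same bit trick outside the package:

    import torch

    x = torch.full((100_000,), 1.0 + 1 / 512)      # exactly representable in fp32, not in bf16
    nearest = x.bfloat16().float().mean()          # deterministic round-to-nearest -> 1.0

    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32)
    bits = (x.view(torch.int32) + noise).bitwise_and(-65536)
    stochastic = bits.view(torch.float32).bfloat16().float().mean()

    print(nearest.item(), stochastic.item())  # the stochastic mean stays close to 1.001953125
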
@@ -902,7 +1020,7 @@ def psgd_lb(A, max_abs):


  @decorator
- def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
+ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
      """Update Kronecker product preconditioner Q with pair (V, G)."""
      exprA, exprGs, _ = exprs

@@ -923,10 +1041,10 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
          norm = term2.norm(float('inf'))
          if q.dim() < 2:
              term1 *= q.to(term1.dtype)
-             term1 /= norm.clamp_(min=tiny)
+             term1 /= norm.clamp_(min=tiny_bf16)
          else:
              torch.triu(term1, out=term1)
-             term1 /= psgd_lb(term2, norm).clamp_(tiny)
+             term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
              torch.matmul(term1, q, out=term1)
          if store_triu_as_line:
              term1 = triu_to_line([term1])[0][1]
@@ -935,22 +1053,32 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):


  @decorator_knowngood
- def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
-     """Precondition gradient G with preconditioner Q."""
-     md = min_dtype(preconds)
-     out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
-     if inplace:
-         set_(grad, out)
-         return grad
-     return out.to(grad.dtype)
+ def _compilable_l2_clip_(x):
+     ref = x
+     x = list(map(promote, x))
+     norm = torch._foreach_norm(x)
+     torch._foreach_maximum_(norm, 1e-8)
+     out = torch._foreach_div(x, norm)
+     return stochastic_round_list_(ref, out)
+

+ def l2_clip_(x):
+     x = list_guard(x)
+     return _compilable_l2_clip_(x)

- def norm_clip_(x, scale=None):
+
+ @decorator_knowngood
+ def _compilable_rmsnorm_clip_(x):
+     x = list(map(promote, x))
      norm = torch._foreach_norm(x)
-     if scale is not None:
-         torch._foreach_div_(norm, scale)
-     torch._foreach_div_(x, norm)
-     return x
+     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
+     torch._foreach_maximum_(norm, 1e-6)
+     return torch._foreach_div(x, norm)
+
+
+ def rmsnorm_clip_(x):
+     x = list_guard(x)
+     return _compilable_rmsnorm_clip_(x)


  def mu_law_compress(x, mu=127.0):
@@ -990,18 +1118,24 @@ def identity(x):
      return x


- def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
-     torch._foreach_mul_(grad, 1 / scale)
-     tanh = torch._foreach_tanh(grad)
-     torch._foreach_abs_(grad)
-     torch._foreach_log1p_(grad)
-     grad = [p.copysign_(t) for t, p in zip(tanh, grad)] # torch doesn't have a foreach copysign
-     torch._foreach_lerp_(grad, tanh, lerp) # sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9
-     torch._foreach_mul_(grad, scale)
+ @decorator_knowngood
+ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+     g32 = list(map(promote, grad))
+     [g.mul_(1 / scale) for g in g32]
+     tanh = torch._foreach_tanh(g32)
+     torch._foreach_abs_(g32)
+     torch._foreach_log1p_(g32)
+     [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
+
+     torch._foreach_maximum_(g32, -2)
+     torch._foreach_minimum_(g32, 2)
+     return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]

-     torch._foreach_maximum_(grad, -2)
-     torch._foreach_minimum_(grad, 2)
-     return grad
+
+ def trust_region_clip_(grad, lerp=0.9, scale=1.5):
+     grad = list_guard(grad)
+     lerp, scale = scalar_guard(lerp, grad[0]), scalar_guard(scale, grad[0])
+     return _compilable_trust_region_clip_(grad, lerp, scale)


  @decorator
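
The rewritten trust-region clip above promotes to fp32, fuses the copysign/lerp/scale steps into one list comprehension, and stochastically rounds the result back. Elementwise it computes, with the defaults lerp=0.9 and scale=1.5:

    import torch

    # f(x) = clamp(scale * ((1 - lerp) * sign(x) * log1p(|x| / scale) + lerp * tanh(x / scale)), -2, 2)
    def trust_region(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
        g = x / scale
        soft = torch.sign(x) * torch.log1p(g.abs())
        return (scale * ((1 - lerp) * soft + lerp * torch.tanh(g))).clamp(-2, 2)

    x = torch.tensor([-100.0, -1.0, -0.1, 0.1, 1.0, 100.0])
    print(trust_region(x))  # small inputs pass through almost unchanged; large ones saturate near ±2
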
@@ -1040,60 +1174,57 @@ def update_triu_(q_state, materialised):
          copy_stochastic_(q, m)


- class PSGDBase(StatefulOptimizer):
-     balance_probability: float = 0.01
-
-     def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
-                  preconditioner_update_probability):
-         super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
-         self.rng = random.Random(0x1923213)
-         self._tiny = torch.finfo(torch.bfloat16).tiny
-         if clip_fn is None:
-             clip_fn = identity
-         if preconditioner_update_probability is None:
-             preconditioner_update_probability = precond_update_prob_schedule()
-         self.clip_fn = clip_fn
-         self.preconditioner_update_probability = preconditioner_update_probability
-
-     def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
-         group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
-         if prob is None:
-             prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
-         if group['stochastic_schedule']:
-             return self.rng.random() < prob
-         cumulative_prob = group.get(name, 0)
-         group[name] = cumulative_prob + prob
-         return int(group[name]) > int(cumulative_prob)
-
-     def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
-         for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
-             psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
-
-         if self.should_update(group, self.balance_probability, "balance_prob"):
-             for g, q in zip(grad_list, original_q if original_q else q_list):
-                 if g.dim() > 1:
-                     if store_triu_as_line:
-                         psgd_balance_Q([q_ for _, q_ in q])
-                     else:
-                         psgd_balance_Q(q)
-
-
- # TODO: Figure out why this sometimes crashes
- # @decorator_knowngood
- def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
-                                      clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
+ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
+                        name: str = 'cumulative_prob'):
+     group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+     if not isinstance(prob, float):
+         prob = prob(group[f'{name}_prob_step'])
+     if group['stochastic_schedule']:
+         return rng.random() < prob
+     cumulative_prob = state.get(name, 0)
+     group[name] = cumulative_prob + prob
+     return int(group[name]) > int(cumulative_prob)
+
+
+ @decorator_knowngood
+ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
      md = min_dtype(list(cached_q) + [ea])
      args = [q.to(md) for q in cached_q]
      args = args + [ea.to(md)]
      new = torch.einsum(expr, *args)
-     new = new.to(torch.float32)
-     _compilable_update_([param], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
+     if cast:
+         return new.to(ea.dtype)
+     return new
+
+
+ @decorator_knowngood
+ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+     precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
+     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+
+
+ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+     lr = scalar_guard(lr, param[0])
+     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
+
+
+ @decorator_knowngood
+ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+     md = min_dtype(list(preconds) + [ea])
+     args = [q.to(md) for q in preconds]
+     args = args + args + [ea.to(md)]
+     new = torch.einsum(expr, *args)
+     return new.to(ea.dtype)
+

+ def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+     precond = psgd_precond_grad(expr, grad, *preconds)
+     update_param_(param, precond, lr, decay, caution=caution, grad=grad)

- def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
-                          clip_fn, caution, grad):
-     lr = scalar_guard(lr, param)
-     _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
+
+ def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+     lr = scalar_guard(lr, param[0])
+     _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)


  @decorator_knowngood
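
In the refactor above, psgd_precond_grad now takes the einsum expression first and feeds each Kronecker factor into the contraction twice, the cached variant receives precomputed products instead, and the fused_* wrappers immediately apply the parameter update. For a 2D parameter, the structure is roughly the following; this is a plain matmul illustration of the idea, not the library's generated einsum strings:

    import torch

    G = torch.randn(8, 4)                                          # gradient of a matrix parameter
    Ql, Qr = torch.randn(8, 8).triu(), torch.randn(4, 4).triu()    # per-dimension factors

    # each factor acts on its own dimension roughly as Q.T @ Q; the cached path stores these products
    cached_left, cached_right = Ql.T @ Ql, Qr.T @ Qr
    preconditioned = cached_left @ G @ cached_right
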
@@ -1122,7 +1253,7 @@ def caution(g, update):
      _compilable_cautioning_(g, update)


- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
+ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
      """Anneal preconditioner update probability during beginning of training.

      PSGD benefits from more preconditioner updates at the beginning of training,