heavyball 0.25.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,3 +1,11 @@
1
+ """
2
+
3
+
4
+ Originally from Evan Walters and Omead Pooladzandi, 2024
5
+ Modified under Creative Commons Attribution 4.0 International
6
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
7
+ """
8
+
1
9
  import functools
2
10
  import gc
3
11
  import math
@@ -16,6 +24,7 @@ compile_mode = "max-autotune-no-cudagraphs"
16
24
  dynamic = False
17
25
  compile_mode_recommended_to_none = None
18
26
  zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
27
+ tiny_bf16 = torch.finfo(torch.bfloat16).tiny
19
28
 
20
29
 
21
30
  def decorator(func):
@@ -60,41 +69,34 @@ def warmup(lr: float, step: int, warmup_steps: int):
60
69
 
61
70
  @decorator_knowngood
62
71
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
63
- beta1: Tensor):
72
+ beta1: Tensor, decay: float):
73
+ grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
64
74
  p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
65
75
  for p_, z_, g_ in zip(p32, z32, g32):
76
+ if decay != 0:
77
+ g_.add_(p_, alpha=decay)
66
78
  p_.lerp_(z_, ckp1)
67
- p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
68
- z_.add_(g_, alpha=-lr)
79
+ p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
80
+ z_.add_(g_, alpha=lr)
69
81
  copy_stochastic_list_(p, p32)
70
82
  copy_stochastic_list_(z, z32)
71
83
 
72
84
 
73
- def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
74
- weight = lr ** weight_lr_power * max(step, 1) ** r
75
- weight_sum = weight_sum + weight
76
-
77
- try:
78
- ckp1 = weight / weight_sum
79
- except ZeroDivisionError:
80
- ckp1 = 0
81
- return ckp1, weight_sum
82
-
83
-
84
85
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
85
- z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
86
- weight = lr ** weight_lr_power * max(step, 1) ** r
86
+ z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
87
+ weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
87
88
  weight_sum = weight_sum + weight
88
89
 
89
90
  try:
90
91
  ckp1 = weight / weight_sum
91
92
  except ZeroDivisionError:
92
93
  ckp1 = 0
94
+ ckp1 = 0
93
95
 
94
96
  # These operations update y in-place,
95
97
  # without computing x explicitly.
96
- lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
97
- _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
98
+ lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
99
+ _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
98
100
  return weight_sum
99
101
 
100
102
 
@@ -162,10 +164,13 @@ def beta_debias(beta, step):
162
164
  @decorator_knowngood
163
165
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
164
166
  out: List[Optional[Tensor]]):
165
- torch._foreach_mul_(state, beta2)
166
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
167
- denom = torch._foreach_sqrt(state)
168
- [denom.clamp_(min=eps) for denom in denom]
167
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
168
+ torch._foreach_mul_(s32, beta2)
169
+ [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
170
+ denom = torch._foreach_sqrt(s32)
171
+ [d.clamp_(min=eps) for d in denom]
172
+ copy_stochastic_list_(state, s32)
173
+
169
174
  if out[0] is None:
170
175
  return denom
171
176
 
@@ -179,10 +184,27 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
179
184
  return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
180
185
 
181
186
 
182
- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
183
- minimum: float = 1e-3, eps: float = 1e-8):
184
- if clip_val <= 0:
185
- return
187
+ @decorator_knowngood
188
+ def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
189
+ out: List[Optional[Tensor]]):
190
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
191
+ torch._foreach_mul_(s32, beta2)
192
+ [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
193
+ denom = torch._foreach_sqrt(s32)
194
+ [d.clamp_(min=eps) for d in denom]
195
+ out = torch._foreach_div(g32, denom)
196
+ copy_stochastic_list_(state, s32)
197
+ return stochastic_round_list_(grad, out)
198
+
199
+
200
+ def scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps):
201
+ grad, exp_avg_sq = list_guard(grad), list_guard(exp_avg_sq)
202
+ beta2, eps = scalar_guard(beta2, grad[0]), scalar_guard(eps, grad[0])
203
+ return _compilable_scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps, grad)
204
+
205
+
206
+ @decorator_knowngood
207
+ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
186
208
  p_norm = torch._foreach_norm(parameters)
187
209
  g_norm = torch._foreach_norm(gradients)
188
210
  torch._foreach_maximum_(p_norm, minimum)
@@ -190,7 +212,16 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
190
212
  torch._foreach_div_(p_norm, g_norm)
191
213
  torch._foreach_mul_(p_norm, clip_val)
192
214
  torch._foreach_minimum_(p_norm, 1)
193
- torch._foreach_mul_(gradients, p_norm)
215
+ return torch._foreach_mul(gradients, p_norm)
216
+
217
+
218
+ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
219
+ minimum: float = 1e-3, eps: float = 1e-8):
220
+ if clip_val <= 0:
221
+ return gradients
222
+ parameters, gradients = list_guard(parameters), list_guard(gradients)
223
+ clip_val = scalar_guard(clip_val, parameters[0])
224
+ return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
194
225
 
195
226
 
196
227
  def is_compiling():
@@ -205,10 +236,7 @@ def set_(dst: Tensor, src: Tensor):
205
236
  return
206
237
  if src.shape != dst.shape:
207
238
  src = src.reshape_as(dst)
208
- if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
209
- dst.set_(src)
210
- else:
211
- dst.copy_(src)
239
+ dst.copy_(src)
212
240
 
213
241
 
214
242
  def clean():
@@ -353,8 +381,6 @@ def get_orthogonal_matrix(mat):
353
381
 
354
382
  Q = torch.flip(Q, [1])
355
383
 
356
- if not float_data:
357
- Q = Q.to(original_device).type(original_type)
358
384
  final.append(Q)
359
385
 
360
386
  return final
@@ -369,6 +395,27 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
369
395
  copy_stochastic_(x_, x32)
370
396
 
371
397
 
398
+ def get_beta1(group):
399
+ beta = None
400
+ if 'beta' in group:
401
+ beta = group['beta']
402
+ if beta is None and 'betas' in group:
403
+ beta = group['betas'][0]
404
+ if beta is None:
405
+ raise ValueError("Beta not found in group.")
406
+ return beta
407
+
408
+
409
+ def get_beta2(group):
410
+ beta = None
411
+ if 'beta2_scale' in group:
412
+ step = max(group.get("step", 1), 1)
413
+ return 1 - step ** -group['beta2_scale']
414
+ if 'betas' in group:
415
+ return group['betas'][1]
416
+ raise ValueError("Beta2 not found in group.")
417
+
418
+
372
419
  def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
373
420
  x, y = list_guard(x), list_guard(y)
374
421
  a = scalar_guard(a, x[0])
@@ -435,35 +482,35 @@ def min_dtype(xs: List[Tensor]):
435
482
  return torch.float32
436
483
 
437
484
 
438
- def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
485
+ def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
439
486
  """
440
487
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
441
488
  """
442
- compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
443
- if state['Q'] is None:
444
- state['Q'] = get_orthogonal_matrix(state['GG'])
489
+ compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
445
490
  if update_precond:
446
- get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])
491
+ get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)
447
492
 
448
493
 
449
- def init_preconditioner(grad, state, max_precond_dim=10000, precondition_1d=False):
494
+ def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
450
495
  """
451
496
  Initializes the preconditioner matrices (L and R in the paper).
452
497
  """
453
- state['Q'] = None # Will hold all the eigenbases of the preconditioner.
454
498
  state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
455
499
  if grad.dim() == 1:
456
- if not precondition_1d or grad.shape[0] > max_precond_dim:
500
+ if precondition_1d or grad.shape[0] > max_precond_dim:
501
+ state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
502
+ else:
457
503
  state['GG'].append([])
458
- return
459
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
460
- return
461
504
 
462
- for sh in grad.shape:
463
- if sh > max_precond_dim:
464
- state['GG'].append([])
465
- else:
466
- state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
505
+ else:
506
+ for sh in grad.shape:
507
+ if sh > max_precond_dim:
508
+ state['GG'].append([])
509
+ else:
510
+ state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
511
+
512
+ compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
513
+ state['Q'] = get_orthogonal_matrix(state['GG'])
467
514
 
468
515
 
469
516
  @decorator
@@ -629,84 +676,178 @@ class StatefulOptimizer(torch.optim.Optimizer):
629
676
  return loss
630
677
 
631
678
 
632
-
633
- class ScheduleFree(StatefulOptimizer):
634
- def eval(self):
635
- for group in self.param_groups:
636
- train_mode = group['train_mode']
637
- beta1 = group['beta'] if 'beta' in group else group['betas'][0]
638
- if beta1 > 0 and train_mode:
639
- for p in group['params']:
640
- state = self.state_(p)
641
- if 'z' in state:
642
- # Set p.data to x
643
- z = promote(state['z'])
644
- p32 = promote(p.data)
645
- p32.lerp_(end=z, weight=1 - 1 / beta1)
646
- copy_stochastic_(p.data, p32)
647
- group['train_mode'] = False
648
-
649
- def train(self):
650
- for group in self.param_groups:
651
- train_mode = group['train_mode']
652
- beta1 = group['beta'] if 'beta' in group else group['betas'][0]
653
- if beta1 > 0 and not train_mode:
654
- for p in group['params']:
655
- state = self.state_(p)
656
- if 'z' in state:
657
- z = promote(state['z'])
658
- p32 = promote(p.data)
659
- p32.lerp_(end=z, weight=1 - beta1)
660
- copy_stochastic_(p.data, p32)
661
- group['train_mode'] = True
662
-
663
- def _step(self):
664
- raise NotImplementedError
665
-
666
-
667
679
  def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
668
680
  for t, s in zip(target, source):
669
681
  copy_stochastic_(t, s)
670
682
 
671
683
 
672
684
  @decorator_knowngood
673
- def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
674
- grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
685
+ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
686
+ step: Tensor):
687
+ beta1 = beta_debias(beta1, step)
688
+ beta2 = beta_debias(beta2, step)
689
+
690
+ g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
691
+
692
+ [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
693
+ denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
694
+ u32 = torch._foreach_div(exp_avg32, denom)
695
+
696
+ copy_stochastic_list_(exp_avg, exp_avg32)
697
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
698
+ return stochastic_round_list_(exp_avg, u32)
699
+
700
+
701
+ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
702
+ exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
703
+ beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step,
704
+ exp_avg[0])
705
+ return _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
706
+
707
+
708
+ @decorator_knowngood
709
+ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
710
+ beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
711
+ caution: bool):
712
+ beta1 = beta_debias(beta1, step)
713
+ beta2 = beta_debias(beta2, step)
714
+
715
+ g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
716
+
717
+ [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
718
+ denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
719
+ u32 = torch._foreach_div(exp_avg32, denom)
720
+
721
+ copy_stochastic_list_(exp_avg, exp_avg32)
722
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
723
+ _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
724
+
725
+
726
+ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
727
+ beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
728
+ y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
729
+ beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
730
+ return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
731
+
732
+
733
+ @decorator_knowngood
734
+ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: Tensor,
735
+ beta2: Tensor, step: Tensor):
675
736
  beta1 = beta_debias(beta1, step)
676
737
  beta2 = beta_debias(beta2, step)
677
738
 
678
- g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
739
+ gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
679
740
 
680
- stochastic_lerp_(exp_avg, g32, 1 - beta1)
681
741
  denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
742
+ gp32 = torch._foreach_div(gp32, denom)
743
+ stochastic_lerp_(exp_avg, gp32, 1 - beta1)
682
744
 
683
745
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
684
- return denom
685
746
 
686
747
 
687
- def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
688
- beta1: float, beta2: float, step: int):
689
- exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
690
- grad), list_guard(grad_projected)
748
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: float, beta2: float,
749
+ step: int):
750
+ exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
691
751
  beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
692
- denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
693
- return denom
752
+ _compilable_laprop_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
753
+ return exp_avg
694
754
 
695
755
 
696
756
  @decorator_knowngood
697
- def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
698
- """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
699
- # create a random 16 bit integer
700
- result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
757
+ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
758
+ grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
759
+ decay: Tensor, caution: bool):
760
+ beta1 = beta_debias(beta1, step)
761
+ beta2 = beta_debias(beta2, step)
701
762
 
702
- # add the random number to the lower 16 bit of the mantissa
703
- result.add_(source.view(dtype=torch.int32))
763
+ gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
764
+
765
+ denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
766
+ gp32 = torch._foreach_div(gp32, denom)
767
+ stochastic_lerp_(exp_avg, gp32, 1 - beta1)
768
+ update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
769
+
770
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
771
+
772
+
773
+ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
774
+ beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
775
+ y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
776
+ beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
777
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
778
+
779
+
780
+ @decorator_knowngood
781
+ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
782
+ g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
783
+ update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
784
+
785
+ beta1 = beta_debias(beta1, step)
786
+ denom = torch._foreach_sqrt(exp_avg_sq32)
787
+ [denom.clamp_(min=eps) for denom in denom]
788
+ torch._foreach_mul_(exp_avg32, beta1)
789
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
790
+
791
+ beta2 = beta_debias(beta2, step + 1)
792
+ torch._foreach_mul_(exp_avg_sq32, beta2)
793
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
794
+
795
+ copy_stochastic_list_(exp_avg, exp_avg32)
796
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
797
+
798
+
799
+ def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
800
+ y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
801
+ beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
802
+ _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
803
+
804
+
805
+ @decorator_knowngood
806
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
807
+ g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
808
+ update = [e.clone() for e in exp_avg]
809
+
810
+ beta1 = beta_debias(beta1, step)
811
+ denom = torch._foreach_sqrt(exp_avg_sq32)
812
+ [denom.clamp_(min=1e-8) for denom in denom]
813
+ torch._foreach_mul_(exp_avg32, beta1)
814
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
815
+
816
+ beta2 = beta_debias(beta2, step + 1)
817
+ torch._foreach_mul_(exp_avg_sq32, beta2)
818
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
819
+
820
+ copy_stochastic_list_(exp_avg, exp_avg32)
821
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
822
+
823
+ return update
824
+
825
+
826
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
827
+ grad, exp_avg_sq, exp_avg = list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
828
+ beta1, beta2, step = scalar_guard(beta1, grad[0]), scalar_guard(beta2, grad[0]), scalar_guard(step, grad[0])
829
+ return _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
830
+
831
+
832
+ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
833
+ return [stochastic_round_(r, s) for r, s in zip(ref, source)]
704
834
 
705
- # mask off the lower 16 bit of the mantissa
835
+
836
+ @decorator_knowngood
837
+ def stochastic_round_(ref: Tensor, source: Tensor):
838
+ if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
839
+ return source
840
+ if ref.dtype != torch.bfloat16:
841
+ return source.to(ref.dtype)
842
+ result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
843
+ result.add_(source.view(dtype=torch.int32))
706
844
  result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
845
+ return result.view(dtype=torch.float32).bfloat16()
707
846
 
708
- # copy the higher 16 bit into the target tensor
709
- target.copy_(result.view(dtype=torch.float32))
847
+
848
+ @decorator_knowngood
849
+ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
850
+ target.copy_(stochastic_round_(target, source))
710
851
 
711
852
 
712
853
  def copy_stochastic_(target: Tensor, source: Tensor):
@@ -879,7 +1020,7 @@ def psgd_lb(A, max_abs):
879
1020
 
880
1021
 
881
1022
  @decorator
882
- def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
1023
+ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
883
1024
  """Update Kronecker product preconditioner Q with pair (V, G)."""
884
1025
  exprA, exprGs, _ = exprs
885
1026
 
@@ -900,10 +1041,10 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
900
1041
  norm = term2.norm(float('inf'))
901
1042
  if q.dim() < 2:
902
1043
  term1 *= q.to(term1.dtype)
903
- term1 /= norm.clamp_(min=tiny)
1044
+ term1 /= norm.clamp_(min=tiny_bf16)
904
1045
  else:
905
1046
  torch.triu(term1, out=term1)
906
- term1 /= psgd_lb(term2, norm).clamp_(tiny)
1047
+ term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
907
1048
  torch.matmul(term1, q, out=term1)
908
1049
  if store_triu_as_line:
909
1050
  term1 = triu_to_line([term1])[0][1]
@@ -912,22 +1053,32 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
912
1053
 
913
1054
 
914
1055
  @decorator_knowngood
915
- def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
916
- """Precondition gradient G with preconditioner Q."""
917
- md = min_dtype(preconds)
918
- out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
919
- if inplace:
920
- set_(grad, out)
921
- return grad
922
- return out.to(grad.dtype)
1056
+ def _compilable_l2_clip_(x):
1057
+ ref = x
1058
+ x = list(map(promote, x))
1059
+ norm = torch._foreach_norm(x)
1060
+ torch._foreach_maximum_(norm, 1e-8)
1061
+ out = torch._foreach_div(x, norm)
1062
+ return stochastic_round_list_(ref, out)
1063
+
923
1064
 
1065
+ def l2_clip_(x):
1066
+ x = list_guard(x)
1067
+ return _compilable_l2_clip_(x)
924
1068
 
925
- def norm_clip_(x, scale=None):
1069
+
1070
+ @decorator_knowngood
1071
+ def _compilable_rmsnorm_clip_(x):
1072
+ x = list(map(promote, x))
926
1073
  norm = torch._foreach_norm(x)
927
- if scale is not None:
928
- torch._foreach_div_(norm, scale)
929
- torch._foreach_div_(x, norm)
930
- return x
1074
+ norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
1075
+ torch._foreach_maximum_(norm, 1e-6)
1076
+ return torch._foreach_div(x, norm)
1077
+
1078
+
1079
+ def rmsnorm_clip_(x):
1080
+ x = list_guard(x)
1081
+ return _compilable_rmsnorm_clip_(x)
931
1082
 
932
1083
 
933
1084
  def mu_law_compress(x, mu=127.0):
@@ -967,18 +1118,24 @@ def identity(x):
967
1118
  return x
968
1119
 
969
1120
 
970
- def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
971
- torch._foreach_mul_(grad, 1 / scale)
972
- tanh = torch._foreach_tanh(grad)
973
- torch._foreach_abs_(grad)
974
- torch._foreach_log1p_(grad)
975
- grad = [p.copysign_(t) for t, p in zip(tanh, grad)] # torch doesn't have a foreach copysign
976
- torch._foreach_lerp_(grad, tanh, lerp) # sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9
977
- torch._foreach_mul_(grad, scale)
1121
+ @decorator_knowngood
1122
+ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
1123
+ g32 = list(map(promote, grad))
1124
+ [g.mul_(1 / scale) for g in g32]
1125
+ tanh = torch._foreach_tanh(g32)
1126
+ torch._foreach_abs_(g32)
1127
+ torch._foreach_log1p_(g32)
1128
+ [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
978
1129
 
979
- torch._foreach_maximum_(grad, -2)
980
- torch._foreach_minimum_(grad, 2)
981
- return grad
1130
+ torch._foreach_maximum_(g32, -2)
1131
+ torch._foreach_minimum_(g32, 2)
1132
+ return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
1133
+
1134
+
1135
+ def trust_region_clip_(grad, lerp=0.9, scale=1.5):
1136
+ grad = list_guard(grad)
1137
+ lerp, scale = scalar_guard(lerp, grad[0]), scalar_guard(scale, grad[0])
1138
+ return _compilable_trust_region_clip_(grad, lerp, scale)
982
1139
 
983
1140
 
984
1141
  @decorator
@@ -1017,60 +1174,57 @@ def update_triu_(q_state, materialised):
1017
1174
  copy_stochastic_(q, m)
1018
1175
 
1019
1176
 
1020
- class PSGDBase(StatefulOptimizer):
1021
- balance_probability: float = 0.01
1022
-
1023
- def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
1024
- preconditioner_update_probability):
1025
- super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
1026
- self.rng = random.Random(0x1923213)
1027
- self._tiny = torch.finfo(torch.bfloat16).tiny
1028
- if clip_fn is None:
1029
- clip_fn = identity
1030
- if preconditioner_update_probability is None:
1031
- preconditioner_update_probability = precond_update_prob_schedule()
1032
- self.clip_fn = clip_fn
1033
- self.preconditioner_update_probability = preconditioner_update_probability
1034
-
1035
- def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
1036
- group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
1037
- if prob is None:
1038
- prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
1039
- if group['stochastic_schedule']:
1040
- return self.rng.random() < prob
1041
- cumulative_prob = group.get(name, 0)
1042
- group[name] = cumulative_prob + prob
1043
- return int(group[name]) > int(cumulative_prob)
1044
-
1045
- def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
1046
- for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
1047
- psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
1048
-
1049
- if self.should_update(group, self.balance_probability, "balance_prob"):
1050
- for g, q in zip(grad_list, original_q if original_q else q_list):
1051
- if g.dim() > 1:
1052
- if store_triu_as_line:
1053
- psgd_balance_Q([q_ for _, q_ in q])
1054
- else:
1055
- psgd_balance_Q(q)
1056
-
1057
-
1058
- # TODO: Figure out why this sometimes crashes
1059
- # @decorator_knowngood
1060
- def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
1061
- clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
1177
+ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1178
+ name: str = 'cumulative_prob'):
1179
+ group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
1180
+ if not isinstance(prob, float):
1181
+ prob = prob(group[f'{name}_prob_step'])
1182
+ if group['stochastic_schedule']:
1183
+ return rng.random() < prob
1184
+ cumulative_prob = state.get(name, 0)
1185
+ group[name] = cumulative_prob + prob
1186
+ return int(group[name]) > int(cumulative_prob)
1187
+
1188
+
1189
+ @decorator_knowngood
1190
+ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
1062
1191
  md = min_dtype(list(cached_q) + [ea])
1063
1192
  args = [q.to(md) for q in cached_q]
1064
1193
  args = args + [ea.to(md)]
1065
1194
  new = torch.einsum(expr, *args)
1066
- new = new.to(torch.float32)
1067
- _compilable_update_([param], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
1195
+ if cast:
1196
+ return new.to(ea.dtype)
1197
+ return new
1198
+
1068
1199
 
1200
+ @decorator_knowngood
1201
+ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
1202
+ precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
1203
+ update_param_(param, precond, lr, decay, caution=caution, grad=grad)
1204
+
1205
+
1206
+ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
1207
+ lr = scalar_guard(lr, param[0])
1208
+ _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
1069
1209
 
1070
- def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
1071
- clip_fn, caution, grad):
1072
- lr = scalar_guard(lr, param)
1073
- _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
1210
+
1211
+ @decorator_knowngood
1212
+ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
1213
+ md = min_dtype(list(preconds) + [ea])
1214
+ args = [q.to(md) for q in preconds]
1215
+ args = args + args + [ea.to(md)]
1216
+ new = torch.einsum(expr, *args)
1217
+ return new.to(ea.dtype)
1218
+
1219
+
1220
+ def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
1221
+ precond = psgd_precond_grad(expr, grad, *preconds)
1222
+ update_param_(param, precond, lr, decay, caution=caution, grad=grad)
1223
+
1224
+
1225
+ def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
1226
+ lr = scalar_guard(lr, param[0])
1227
+ _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)
1074
1228
 
1075
1229
 
1076
1230
  @decorator_knowngood
@@ -1099,7 +1253,7 @@ def caution(g, update):
1099
1253
  _compilable_cautioning_(g, update)
1100
1254
 
1101
1255
 
1102
- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
1256
+ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
1103
1257
  """Anneal preconditioner update probability during beginning of training.
1104
1258
 
1105
1259
  PSGD benefits from more preconditioner updates at the beginning of training,