heavyball 0.25.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,3 +1,11 @@
+ """
+
+
+ Originally from Evan Walters and Omead Pooladzandi, 2024
+ Modified under Creative Commons Attribution 4.0 International
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+ """
+
  import functools
  import gc
  import math
@@ -15,7 +23,8 @@ from torch.utils._pytree import tree_map
  compile_mode = "max-autotune-no-cudagraphs"
  dynamic = False
  compile_mode_recommended_to_none = None
- zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
+ zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
+ tiny_bf16 = torch.finfo(torch.bfloat16).tiny


  def decorator(func):
@@ -60,30 +69,22 @@ def warmup(lr: float, step: int, warmup_steps: int):

  @decorator_knowngood
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
- beta1: Tensor):
+ beta1: Tensor, decay: float):
+ grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
  p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
  for p_, z_, g_ in zip(p32, z32, g32):
+ if decay != 0:
+ g_.add_(p_, alpha=decay)
  p_.lerp_(z_, ckp1)
- p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
- z_.add_(g_, alpha=-lr)
+ p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
+ z_.add_(g_, alpha=lr)
  copy_stochastic_list_(p, p32)
  copy_stochastic_list_(z, z32)


- def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
- weight = lr ** weight_lr_power * max(step, 1) ** r
- weight_sum = weight_sum + weight
-
- try:
- ckp1 = weight / weight_sum
- except ZeroDivisionError:
- ckp1 = 0
- return ckp1, weight_sum
-
-
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
- z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
- weight = lr ** weight_lr_power * max(step, 1) ** r
+ z: List[Tensor], grad: List[Tensor], r: float = 0.0, step: int = 0, decay: float = 0.0):
+ weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
  weight_sum = weight_sum + weight

  try:
@@ -91,10 +92,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
  except ZeroDivisionError:
  ckp1 = 0

- # These operations update y in-place,
- # without computing x explicitly.
- lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
- _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
+ grad, parameters, z = list_guard(grad, parameters, z)
+ lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
+ _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
  return weight_sum


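For orientation: schedule_free_ keeps a fast iterate z driven by plain gradient steps and folds a running weighted average into the parameters through the ckp1 weight computed from weight / weight_sum above. A minimal single-tensor sketch, written with the pre-1.1.0 sign convention visible in the removed lines (illustrative only, not the fused compiled kernel):

    import torch

    def schedule_free_step(y: torch.Tensor, z: torch.Tensor, grad: torch.Tensor,
                           lr: float, beta1: float, ckp1: float) -> None:
        # Pull the evaluation point y toward the running average of the z iterates,
        # then take the gradient step on both sequences (0.25.1-style signs).
        y.lerp_(z, ckp1)
        y.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
        z.add_(grad, alpha=-lr)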
@@ -162,10 +162,13 @@ def beta_debias(beta, step):
  @decorator_knowngood
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
  out: List[Optional[Tensor]]):
- torch._foreach_mul_(state, beta2)
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
- denom = torch._foreach_sqrt(state)
- [denom.clamp_(min=eps) for denom in denom]
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+ torch._foreach_mul_(s32, beta2)
+ [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+ denom = torch._foreach_sqrt(s32)
+ [d.clamp_(min=eps) for d in denom]
+ copy_stochastic_list_(state, s32)
+
  if out[0] is None:
  return denom

@@ -174,15 +177,32 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens


  def exp_avg_sq_(state, grad, beta2, eps, out=None):
- state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
- beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
+ state, grad, out = list_guard(state, grad, out)
+ beta2, eps = scalar_guard(beta2, eps, state[0])
  return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)


- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
- minimum: float = 1e-3, eps: float = 1e-8):
- if clip_val <= 0:
- return
+ @decorator_knowngood
+ def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+ torch._foreach_mul_(s32, beta2)
+ [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+ denom = torch._foreach_sqrt(s32)
+ [d.clamp_(min=eps) for d in denom]
+ out = torch._foreach_div_(g32, denom)
+ copy_stochastic_list_(state, s32)
+ copy_stochastic_list_(grad, out)
+
+
+ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+ grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+ beta2, eps = scalar_guard(beta2, eps, grad[0])
+ _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+ return grad
+
+
+ @decorator_knowngood
+ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
  p_norm = torch._foreach_norm(parameters)
  g_norm = torch._foreach_norm(gradients)
  torch._foreach_maximum_(p_norm, minimum)
@@ -190,7 +210,16 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
  torch._foreach_div_(p_norm, g_norm)
  torch._foreach_mul_(p_norm, clip_val)
  torch._foreach_minimum_(p_norm, 1)
- torch._foreach_mul_(gradients, p_norm)
+ return torch._foreach_mul(gradients, p_norm)
+
+
+ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
+ minimum: float = 1e-3, eps: float = 1e-8):
+ if clip_val <= 0:
+ return gradients
+ parameters, gradients = list_guard(parameters, gradients)
+ clip_val = scalar_guard(clip_val, parameters[0])
+ return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)


  def is_compiling():
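The refactor splits the clipping into a compiled _compilable_agc_ kernel and a thin adaptive_gradient_clipping_ wrapper, and it now returns scaled gradients instead of mutating them in place; the rule itself is unchanged: each gradient is multiplied by min(1, clip_val * ||p|| / ||g||), with the two norms floored at minimum and eps. A per-tensor sketch (helper name is mine, not part of the module):

    import torch

    def agc_scale(param: torch.Tensor, grad: torch.Tensor, clip_val: float,
                  minimum: float = 1e-3, eps: float = 1e-8) -> torch.Tensor:
        p_norm = param.norm().clamp(min=minimum)
        g_norm = grad.norm().clamp(min=eps)
        # Shrink the gradient only when it is large relative to the parameter.
        return grad * (p_norm / g_norm * clip_val).clamp(max=1.0)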
@@ -205,10 +234,7 @@ def set_(dst: Tensor, src: Tensor):
  return
  if src.shape != dst.shape:
  src = src.reshape_as(dst)
- if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
- dst.set_(src)
- else:
- dst.copy_(src)
+ dst.copy_(src)


  def clean():
@@ -226,33 +252,29 @@ def set_torch():


  @decorator
- def zeropower_via_newtonschulz5(G, init, steps=2, eps=1e-7):
+ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
  """
- Modified from "modded-nanogpt" under the MIT license:
- Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
  quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
  of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
  zero even beyond the point where the iteration no longer converges all the way to one everywhere
  on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
- where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
  performance at all relative to UV^T, where USV^T = G is the SVD.
  """
  assert len(G.shape) == 2
  a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.float()
- init = init / (init.norm() + eps) # ensure top singular value <= 1
- X = X / (X.norm() + eps) # ensure top singular value <= 1
+ X = G.bfloat16()
+ X /= (X.norm() + eps) # ensure top singular value <= 1
  if G.size(0) > G.size(1):
  X = X.T
  for _ in range(steps):
- A = X @ X.T # preconditioner
- B = A @ init
- init = X = a * init + b * B + c * A @ B
+ A = X @ X.T
+ B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+ X = a * X + B @ X
  if G.size(0) > G.size(1):
  X = X.T
- return X
+ return X.to(G.dtype)


  def ortho(x):
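Re-assembled from the added lines (indentation restored, the @decorator wrapper omitted), the new Newton-Schulz routine reads:

    import torch

    def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
        assert len(G.shape) == 2
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.bfloat16()
        X /= (X.norm() + eps)  # ensure top singular value <= 1
        if G.size(0) > G.size(1):
            X = X.T
        for _ in range(steps):
            A = X @ X.T
            B = b * A + c * A @ A
            X = a * X + B @ X
        if G.size(0) > G.size(1):
            X = X.T
        return X.to(G.dtype)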
@@ -264,6 +286,53 @@ def ortho(x):
  raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")


+ @decorator_knowngood
+ def _compilable_heavyball_momentum_(state, grad, beta):
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+ torch._foreach_mul_(s32, beta)
+ torch._foreach_add_(s32, g32)
+ copy_stochastic_list_(state, s32)
+ copy_stochastic_list_(grad, s32)
+
+
+ @decorator_knowngood
+ def _compilable_nesterov_momentum_(state, grad, beta):
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+ torch._foreach_mul_(s32, beta)
+ torch._foreach_add_(s32, g32)
+ [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+ copy_stochastic_list_(state, s32)
+ copy_stochastic_list_(grad, g32)
+
+
+ def heavyball_momentum(state, grad, beta):
+ state, grad = list_guard(state, grad)
+ beta = scalar_guard(beta, state[0])
+ _compilable_heavyball_momentum_(state, grad, beta)
+ return grad
+
+
+ def nesterov_momentum(state, grad, beta):
+ state, grad = list_guard(state, grad)
+ beta = scalar_guard(beta, state[0])
+ _compilable_nesterov_momentum_(state, grad, beta)
+ return grad
+
+
+ @decorator_knowngood
+ def inplace_orthogonal_(x, mode, out):
+ if mode == 'qr':
+ y = torch.linalg.qr(x).Q
+ elif mode == 'svd':
+ u, s, v = torch.linalg.svd(x)
+ y = u @ v.T
+ elif mode == 'newtonschulz':
+ y = zeropower_via_newtonschulz5(x, 5)
+ else:
+ raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+ set_(out, y)
+
+
  def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
  """
  Computes the eigenbases of the preconditioner using one round of power iteration
@@ -294,17 +363,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
  est_eig = torch.einsum('ij,ij->j', o, tmp)
  sort_idx = torch.argsort(est_eig, descending=True)
  indices.append(sort_idx)
- if zeroth_power_mode == 'eigh':
- set_(q, torch.linalg.eigh(m)[1])
- elif zeroth_power_mode.startswith('newtonschulz'):
- iterations = zeroth_power_mode[len('newtonschulz'):]
- if iterations == '':
- iterations = 10
- else:
- iterations = int(iterations)
- set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
- else:
- set_(q, ortho(tmp[:, sort_idx]))
+ inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)

  indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
  for i, ind in enumerate(indices))
@@ -353,8 +412,6 @@ def get_orthogonal_matrix(mat):

  Q = torch.flip(Q, [1])

- if not float_data:
- Q = Q.to(original_device).type(original_type)
  final.append(Q)

  return final
@@ -369,24 +426,57 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
  copy_stochastic_(x_, x32)


+ def get_beta1(group):
+ beta = None
+ if 'beta' in group:
+ beta = group['beta']
+ if beta is None and 'betas' in group:
+ beta = group['betas'][0]
+ if beta is None:
+ raise ValueError("Beta not found in group.")
+ return beta
+
+
+ def get_beta2(group):
+ if 'beta2_scale' in group:
+ step = max(group.get("step", 1), 1)
+ return 1 - step ** -group['beta2_scale']
+ if 'betas' in group:
+ return group['betas'][1]
+ raise ValueError("Beta2 not found in group.")
+
+
  def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
- x, y = list_guard(x), list_guard(y)
+ x, y = list_guard(x, y)
  a = scalar_guard(a, x[0])
  _compilable_stochastic_lerp_(x, y, a)


- def list_guard(x):
- if isinstance(x, (list, tuple)):
- return x
- return [x]
+ def list_guard(*xs):
+ out = []
+ for x in xs:
+ if isinstance(x, (list, tuple)):
+ out.append(x)
+ else:
+ out.append([x])
+ if len(xs) == 1:
+ return out[0]
+ return out


- def scalar_guard(x, ref):
- if isinstance(x, float):
- return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
- if isinstance(x, int):
- return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
- return x
+ def scalar_guard(*args):
+ *xs, ref = args
+ out = []
+ for x in xs:
+ if isinstance(x, float):
+ out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+ elif isinstance(x, int):
+ out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+ else:
+ out.append(x)
+ if len(xs) == 1:
+ return out[0]
+ return out


  @decorator_knowngood
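list_guard and scalar_guard are now variadic, which is what lets the call sites elsewhere in this diff collapse several separate guards into a single line. A small usage sketch (assuming the heavyball.utils module layout):

    import torch
    from heavyball.utils import list_guard, scalar_guard

    p = torch.zeros(4)
    params = list_guard(p)                     # -> [p]; lists and tuples pass through unchanged
    beta1, beta2 = scalar_guard(0.9, 0.95, p)  # floats become 0-dim tensors on p's device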
@@ -399,7 +489,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f


  def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
- x, y = list_guard(x), list_guard(y)
+ x, y = list_guard(x, y)
  alpha = scalar_guard(alpha, x[0])
  _compilable_stochastic_add_(x, y, alpha)

@@ -435,35 +525,35 @@ def min_dtype(xs: List[Tensor]):
  return torch.float32


- def update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond):
+ def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
  """
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
  """
- compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
- if state['Q'] is None:
- state['Q'] = get_orthogonal_matrix(state['GG'])
+ compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
  if update_precond:
- get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])
+ get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)


- def init_preconditioner(grad, state, max_precond_dim=10000, precondition_1d=False):
+ def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
  """
  Initializes the preconditioner matrices (L and R in the paper).
  """
- state['Q'] = None # Will hold all the eigenbases of the preconditioner.
  state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
  if grad.dim() == 1:
- if not precondition_1d or grad.shape[0] > max_precond_dim:
+ if precondition_1d or grad.shape[0] > max_precond_dim:
+ state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
+ else:
  state['GG'].append([])
- return
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
- return

- for sh in grad.shape:
- if sh > max_precond_dim:
- state['GG'].append([])
- else:
- state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+ else:
+ for sh in grad.shape:
+ if sh > max_precond_dim:
+ state['GG'].append([])
+ else:
+ state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+
+ compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
+ state['Q'] = get_orthogonal_matrix(state['GG'])


  @decorator
@@ -629,74 +719,87 @@ class StatefulOptimizer(torch.optim.Optimizer):
  return loss


-
- class ScheduleFree(StatefulOptimizer):
- def eval(self):
- for group in self.param_groups:
- train_mode = group['train_mode']
- beta1 = group['beta'] if 'beta' in group else group['betas'][0]
- if beta1 > 0 and train_mode:
- for p in group['params']:
- state = self.state_(p)
- if 'z' in state:
- # Set p.data to x
- z = promote(state['z'])
- p32 = promote(p.data)
- p32.lerp_(end=z, weight=1 - 1 / beta1)
- copy_stochastic_(p.data, p32)
- group['train_mode'] = False
-
- def train(self):
- for group in self.param_groups:
- train_mode = group['train_mode']
- beta1 = group['beta'] if 'beta' in group else group['betas'][0]
- if beta1 > 0 and not train_mode:
- for p in group['params']:
- state = self.state_(p)
- if 'z' in state:
- z = promote(state['z'])
- p32 = promote(p.data)
- p32.lerp_(end=z, weight=1 - beta1)
- copy_stochastic_(p.data, p32)
- group['train_mode'] = True
-
- def _step(self):
- raise NotImplementedError
-
-
  def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
  for t, s in zip(target, source):
  copy_stochastic_(t, s)


  @decorator_knowngood
- def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
- grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
+ step: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+ g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]

- stochastic_lerp_(exp_avg, g32, 1 - beta1)
- denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+ [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+ denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ u32 = torch._foreach_div(exp_avg32, denom)
+
+ copy_stochastic_list_(exp_avg, exp_avg32)
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ copy_stochastic_list_(grad, u32)
+
+
+ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+ exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+ beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+ _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+ return grad

+
+ @decorator_knowngood
+ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+ beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
+ caution: bool):
+ beta1 = beta_debias(beta1, step)
+ beta2 = beta_debias(beta2, step)
+
+ g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+
+ [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
+ denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ u32 = torch._foreach_div(exp_avg32, denom)
+
+ copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
- return denom
+ _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+
+
+ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+ beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+ y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+ beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
+ return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+

+ @decorator_knowngood
+ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
+ beta2: Tensor, step: Tensor):
+ beta1 = beta_debias(beta1, step)
+ beta2 = beta_debias(beta2, step)
+
+ gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
+
+ denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+ gp32 = torch._foreach_div(gp32, denom)
+ stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ copy_stochastic_list_(grad, exp_avg)

- def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
- beta1: float, beta2: float, step: int):
- exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
- grad), list_guard(grad_projected)
- beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
- denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
- return denom

+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+ exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+ beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+ _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+ return grad


  @decorator_knowngood
- def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
- grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
+ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
+ grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
+ decay: Tensor, caution: bool):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

@@ -705,31 +808,89 @@ def _compilable_laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
  denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
  gp32 = torch._foreach_div(gp32, denom)
  stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+ update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)

  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


- def laprop_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
- beta1: float, beta2: float, step: int):
- exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
- beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
- _compilable_laprop_exp_avg_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
+ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+ beta2: float, step: int, lr: float, decay: float, caution: bool):
+ exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+ beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)


  @decorator_knowngood
- def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
- """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
- # create a random 16 bit integer
- result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+ g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+ update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+
+ beta1 = beta_debias(beta1, step)
+ denom = torch._foreach_sqrt(exp_avg_sq32)
+ [denom.clamp_(min=eps) for denom in denom]
+ torch._foreach_mul_(exp_avg32, beta1)
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+ beta2 = beta_debias(beta2, step + 1)
+ torch._foreach_mul_(exp_avg_sq32, beta2)
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+ copy_stochastic_list_(exp_avg, exp_avg32)
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
+
+ def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+ exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+ beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+ _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+
+
+ @decorator_knowngood
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
+ update = [e.clone() for e in exp_avg]
+
+ beta1 = beta_debias(beta1, step)
+ denom = torch._foreach_sqrt(exp_avg_sq32)
+ [denom.clamp_(min=1e-8) for denom in denom]
+ torch._foreach_mul_(exp_avg32, beta1)
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+
+ beta2 = beta_debias(beta2, step + 1)
+ torch._foreach_mul_(exp_avg_sq32, beta2)
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+
+ copy_stochastic_list_(exp_avg, exp_avg32)
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ copy_stochastic_list_(grad, update)

- # add the random number to the lower 16 bit of the mantissa
- result.add_(source.view(dtype=torch.int32))

- # mask off the lower 16 bit of the mantissa
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+ beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+ _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+ return grad
+
+
+ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
+ return [stochastic_round_(r, s) for r, s in zip(ref, source)]
+
+
+ @decorator_knowngood
+ def stochastic_round_(ref: Tensor, source: Tensor):
+ if source.dtype == torch.bfloat16 or ref.dtype == source.dtype:
+ return source
+ if ref.dtype != torch.bfloat16:
+ return source.to(ref.dtype)
+ result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+ result.add_(source.view(dtype=torch.int32))
  result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
+ return result.view(dtype=torch.float32).bfloat16()

- # copy the higher 16 bit into the target tensor
- target.copy_(result.view(dtype=torch.float32))
+
+ @decorator_knowngood
+ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
+ target.copy_(stochastic_round_(target, source))


  def copy_stochastic_(target: Tensor, source: Tensor):
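stochastic_round_ factors the bfloat16 rounding trick out of copy_stochastic_: a random 16-bit integer is added below the bfloat16 mantissa of the float32 source and the low bits are then masked off, so each value rounds up or down with probability proportional to its distance from the two representable neighbours, making the rounding unbiased in expectation. A standalone sketch of the same bit manipulation:

    import torch

    def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
        assert source.dtype == torch.float32
        noise = torch.randint_like(source, dtype=torch.int32, low=0, high=1 << 16)
        bits = source.view(dtype=torch.int32) + noise  # perturb the low 16 mantissa bits
        bits &= -65536                                 # keep only the bfloat16 payload
        return bits.view(dtype=torch.float32).bfloat16()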
@@ -759,7 +920,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:

  def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
  caution: bool = False, grad: List[Tensor] = None):
- param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+ param, update, grad = list_guard(param, update, grad)
  lr = scalar_guard(lr, param[0])
  if not caution:
  grad = [None] * len(param)
@@ -865,11 +1026,15 @@ def psgd_balance_Q(Q_in):


  def psgd_calc_A_and_conjB(exprA, G, Q):
+ V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+ eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+ eps *= G.norm() / G.numel()
+ G += V * eps
  md = min_dtype(Q + [G])
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
  p = list(range(order))
- conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+ conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
  Q = [promote(q) for q in Q]
  for i, q in enumerate(Q):
  if q.dim() <= 1:
@@ -902,7 +1067,7 @@ def psgd_lb(A, max_abs):


  @decorator
- def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
+ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs

@@ -923,10 +1088,10 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
  term1 *= q.to(term1.dtype)
- term1 /= norm.clamp_(min=tiny)
+ term1 /= norm.clamp_(min=tiny_bf16)
  else:
  torch.triu(term1, out=term1)
- term1 /= psgd_lb(term2, norm).clamp_(tiny)
+ term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
  torch.matmul(term1, q, out=term1)
  if store_triu_as_line:
  term1 = triu_to_line([term1])[0][1]
@@ -935,22 +1100,32 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):


  @decorator_knowngood
- def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
- """Precondition gradient G with preconditioner Q."""
- md = min_dtype(preconds)
- out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
- if inplace:
- set_(grad, out)
- return grad
- return out.to(grad.dtype)
+ def _compilable_l2_clip_(x):
+ ref = x
+ x = list(map(promote, x))
+ norm = torch._foreach_norm(x)
+ torch._foreach_maximum_(norm, 1e-8)
+ out = torch._foreach_div(x, norm)
+ return stochastic_round_list_(ref, out)
+

+ def l2_clip_(x):
+ x = list_guard(x)
+ return _compilable_l2_clip_(x)

- def norm_clip_(x, scale=None):
+
+ @decorator_knowngood
+ def _compilable_rmsnorm_clip_(x):
+ x = list(map(promote, x))
  norm = torch._foreach_norm(x)
- if scale is not None:
- torch._foreach_div_(norm, scale)
- torch._foreach_div_(x, norm)
- return x
+ norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
+ torch._foreach_maximum_(norm, 1e-6)
+ return torch._foreach_div(x, norm)
+
+
+ def rmsnorm_clip_(x):
+ x = list_guard(x)
+ return _compilable_rmsnorm_clip_(x)


  def mu_law_compress(x, mu=127.0):
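rmsnorm_clip_ replaces the removed norm_clip_: instead of dividing by the L2 norm it divides by the RMS, i.e. norm / sqrt(numel), floored at 1e-6. Roughly, per tensor:

    import torch

    def rmsnorm_clip_single(x: torch.Tensor) -> torch.Tensor:
        rms = x.float().norm() / x.numel() ** 0.5
        return x / rms.clamp(min=1e-6)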
@@ -990,18 +1165,24 @@ def identity(x):
  return x


- def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
- torch._foreach_mul_(grad, 1 / scale)
- tanh = torch._foreach_tanh(grad)
- torch._foreach_abs_(grad)
- torch._foreach_log1p_(grad)
- grad = [p.copysign_(t) for t, p in zip(tanh, grad)] # torch doesn't have a foreach copysign
- torch._foreach_lerp_(grad, tanh, lerp) # sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9
- torch._foreach_mul_(grad, scale)
+ @decorator_knowngood
+ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+ g32 = list(map(promote, grad))
+ [g.mul_(1 / scale) for g in g32]
+ tanh = torch._foreach_tanh(g32)
+ torch._foreach_abs_(g32)
+ torch._foreach_log1p_(g32)
+ [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]

- torch._foreach_maximum_(grad, -2)
- torch._foreach_minimum_(grad, 2)
- return grad
+ torch._foreach_maximum_(g32, -2)
+ torch._foreach_minimum_(g32, 2)
+ return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+
+
+ def trust_region_clip_(grad, lerp=0.9, scale=1.5):
+ grad = list_guard(grad)
+ lerp, scale = scalar_guard(lerp, scale, grad[0])
+ return _compilable_trust_region_clip_(grad, lerp, scale)


  @decorator
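The compiled trust-region clip keeps the formula noted in the removed comment, it just promotes to float32 and rounds back stochastically: with x = g / scale the result is scale * lerp(sign(x) * log1p(|x|), tanh(x), 0.9), clamped to [-2, 2]. A per-tensor sketch:

    import torch

    def trust_region_clip_single(g: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
        x = g.float() / scale
        t = torch.tanh(x)
        soft = x.abs().log1p().copysign(t)  # sign(x) * log(1 + |x|)
        return (soft.lerp(t, lerp) * scale).clamp(-2, 2)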
@@ -1040,60 +1221,57 @@ def update_triu_(q_state, materialised):
  copy_stochastic_(q, m)


- class PSGDBase(StatefulOptimizer):
- balance_probability: float = 0.01
-
- def __init__(self, parameters, groups, foreach: bool, stochastic_schedule: bool, clip_fn,
- preconditioner_update_probability):
- super().__init__(parameters, {**groups, 'stochastic_schedule': stochastic_schedule}, foreach)
- self.rng = random.Random(0x1923213)
- self._tiny = torch.finfo(torch.bfloat16).tiny
- if clip_fn is None:
- clip_fn = identity
- if preconditioner_update_probability is None:
- preconditioner_update_probability = precond_update_prob_schedule()
- self.clip_fn = clip_fn
- self.preconditioner_update_probability = preconditioner_update_probability
-
- def should_update(self, group, prob: Optional[float] = None, name: str = 'cumulative_prob'):
- group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
- if prob is None:
- prob = self.preconditioner_update_probability(group[f'{name}_prob_step'])
- if group['stochastic_schedule']:
- return self.rng.random() < prob
- cumulative_prob = group.get(name, 0)
- group[name] = cumulative_prob + prob
- return int(group[name]) > int(cumulative_prob)
-
- def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
- for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
- psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
-
- if self.should_update(group, self.balance_probability, "balance_prob"):
- for g, q in zip(grad_list, original_q if original_q else q_list):
- if g.dim() > 1:
- if store_triu_as_line:
- psgd_balance_Q([q_ for _, q_ in q])
- else:
- psgd_balance_Q(q)
-
-
- # TODO: Figure out why this sometimes crashes
- # @decorator_knowngood
- def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
- clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
+ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
+ name: str = 'cumulative_prob'):
+ group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
+ if not isinstance(prob, float):
+ prob = prob(group[f'{name}_prob_step'])
+ if group['stochastic_schedule']:
+ return rng.random() < prob
+ cumulative_prob = group.get(name, 0)
+ group[name] = cumulative_prob + prob
+ return int(group[name]) > int(cumulative_prob)
+
+
+ @decorator_knowngood
+ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
  md = min_dtype(list(cached_q) + [ea])
  args = [q.to(md) for q in cached_q]
  args = args + [ea.to(md)]
  new = torch.einsum(expr, *args)
- new = new.to(torch.float32)
- _compilable_update_([param], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
+ if cast:
+ return new.to(ea.dtype)
+ return new
+

+ @decorator_knowngood
+ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+ precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
+ update_param_(param, precond, lr, decay, caution=caution, grad=grad)

- def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
- clip_fn, caution, grad):
- lr = scalar_guard(lr, param)
- _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
+
+ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+ lr = scalar_guard(lr, param[0])
+ _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
+
+
+ @decorator_knowngood
+ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+ md = min_dtype(list(preconds) + [ea])
+ args = [q.to(md) for q in preconds]
+ args = args + args + [ea.to(md)]
+ new = torch.einsum(expr, *args)
+ return new.to(ea.dtype)
+
+
+ def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+ precond = psgd_precond_grad(expr, grad, *preconds)
+ update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+
+
+ def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
+ lr = scalar_guard(lr, param[0])
+ _compilable_fused_psgd_precond_grad(expr, ea, param, lr, grad, decay, caution, *preconds)


  @decorator_knowngood
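psgd_precond_grad now takes the einsum expression first and feeds each Kronecker factor in twice, so the gradient is effectively multiplied by Q^T Q factor by factor. For a 2-D parameter with a row factor Ql and a column factor Qr this corresponds to the following dense sketch (an assumption about the einsum expressions, which are built elsewhere and also handle dtypes and diagonal factors):

    import torch

    def precond_grad_2d(G: torch.Tensor, Ql: torch.Tensor, Qr: torch.Tensor) -> torch.Tensor:
        return (Ql.T @ Ql) @ G @ (Qr.T @ Qr)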
@@ -1122,7 +1300,7 @@ def caution(g, update):
  _compilable_cautioning_(g, update)


- def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
+ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
  """Anneal preconditioner update probability during beginning of training.

  PSGD benefits from more preconditioner updates at the beginning of training,