heavyball-1.5.1-py3-none-any.whl → heavyball-1.5.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -163,18 +163,8 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
-class ForeachAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
-
 
-class OrthoAdamW(C.BaseOpt):
+class LaPropOrtho(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
@@ -183,19 +173,7 @@ class OrthoAdamW(C.BaseOpt):
         defaults.pop("self")
         params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_adam)
-
-
-class AdamWOrtho(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
-                         C.orthogonalize_grad_to_param)
+                         C.scale_by_laprop, C.orthogonalize_grad_to_param)
 
 
 class ForeachPSGDKron(C.BaseOpt):
@@ -216,7 +194,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
@@ -279,4 +257,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
heavyball/chainable.py CHANGED
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List
 
 import torch
 
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
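
Both new transforms regularize toward an EMA of the tensor itself rather than toward zero, with strength weight_decay_to_ema * lr, so slow-moving structure survives the decay. A rough sketch of the L2 variant's arithmetic on a single tensor (my reading of the helpers; illustrative, ignoring debiasing and stochastic rounding):

    import torch

    def decay_to_ema(p: torch.Tensor, ema: torch.Tensor, ema_beta=0.999, wd=1e-2):
        ema.lerp_(p, 1 - ema_beta)  # EMA keeps tracking the tensor
        p.lerp_(ema, wd)            # pull the tensor toward its EMA by wd
        return p, ema
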
@@ -295,6 +311,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
 
 
+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
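
nesterov_ema is labeled as equivalent to Grokfast: keep an EMA of past gradients and add it on top of the current one, amplifying the slow component of the signal; mup_approx additionally rescales updates of >=2D tensors by the per-column standard deviation captured at initialization, a muP-style correction. A compact sketch of the Grokfast-style filter (illustrative; the in-tree version operates on lists with stochastic rounding):

    import torch

    def grokfast_ema(grad: torch.Tensor, ema: torch.Tensor, beta=0.98):
        ema.lerp_(grad, 1 - beta)  # low-pass filter over gradient history
        return grad + ema          # boost the filtered component
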
@@ -312,7 +347,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
 
     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
@@ -364,10 +399,13 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache
 
 
-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    else:
+        out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out
 
 
 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
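
"Cautioning" masks update components whose sign disagrees with the raw gradient. Because the preconditioner call now applies the mask itself, the function clears group['caution'] so update_param_ will not mask a second time. A sketch of the masking step, assuming utils._compilable_cautioning follows the Cautious Optimizers recipe of renormalizing the survivors (illustrative, not the package's exact code):

    import torch

    def caution(grad: torch.Tensor, update: torch.Tensor):
        mask = (grad * update) > 0                     # keep sign-consistent entries
        scale = mask.numel() / mask.sum().clamp(min=1)
        return update * mask * scale
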
@@ -387,7 +424,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +432,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -412,6 +449,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate
 
 
+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
@@ -437,8 +479,7 @@ def apply_to_idx(fn, idx):
     return _fn
 
 
-def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -448,10 +489,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
 
 
+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
 
@@ -467,6 +528,8 @@ class ChainOpt(utils.StatefulOptimizer):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']
 
+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
 
         if not vals:
@@ -492,6 +555,7 @@
         else:
             chain(self.state_, group, g, p, *self.fns)
 
+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
 
heavyball/utils.py CHANGED
@@ -317,6 +317,19 @@ def nesterov_momentum(state, grad, beta):
     return grad
 
 
+@decorator_knowngood
+def _compilable_nesterov_ema_(state, grad, beta):
+    ema32 = _lerp32(state, grad, beta)
+    stochastic_add_(grad, ema32, 1)
+
+
+def nesterov_ema(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_ema_(state, grad, beta)
+    return grad
+
+
 def _compilable_grafting(magnitude, direction):
     return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
 
@@ -509,6 +522,19 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)
 
 
+@decorator_knowngood
+def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 * y32)
+
+
+def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    x, y = list_guard(x, y)
+    _compilable_stochastic_multiply_(x, y)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
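
Like stochastic_add_, the new multiply promotes to fp32 and writes the product back with stochastic rounding, so bf16 state does not accumulate a systematic rounding bias. A self-contained sketch of the write-back trick, assuming bf16 is the upper half of fp32's bit pattern (illustrative; copy_stochastic_ in this file is the real implementation):

    import torch

    def copy_stochastic(target: torch.Tensor, source: torch.Tensor):
        if target.dtype != torch.bfloat16:
            target.copy_(source)
            return
        bits = source.float().view(torch.int32)
        bits = bits + torch.randint_like(bits, 0, 1 << 16)  # dither the bits truncation drops
        bits = bits & -65536                                # keep the upper 16 bits
        target.copy_(bits.view(torch.float32))
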
@@ -783,7 +809,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-          eps: float):
+          eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -815,23 +841,23 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
 
 @decorator_knowngood
 def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
-                        beta2: Tensor, step: Tensor):
+                        beta2: Tensor, step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
     gp32 = list(map(promote, grad))
 
-    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
     gp32 = torch._foreach_div(gp32, denom)
     gp32 = _lerp32(exp_avg, gp32, beta1)
 
     copy_stochastic_list_(grad, gp32)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
@@ -970,6 +996,10 @@ def get_soap_precond_schedule(precond_scheduler):
     return _inner
 
 
+def _max_idx(x: List[int]):
+    return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
+
+
 def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
@@ -992,17 +1022,20 @@
 
     scale = scale ** (1 / len(shape))
 
+    dim_diag = [False for _ in shape]
     if memory_save_mode is None:
-        dim_diag = [False for _ in shape]
+        pass
     elif memory_save_mode == "one_diag":
-        rev_sorted_dims = np.argsort(shape)[::-1]
-        dim_diag = [False for _ in shape]
-        dim_diag[rev_sorted_dims[0]] = True
+        dim_diag[_max_idx(shape)] = True
+    elif memory_save_mode == "smart_one_diag":
+        sorted_shape = sorted(shape)
+        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+            dim_diag[_max_idx(shape)] = True
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
         raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
-                         "[None, 'one_diag', 'all_diag']")
+                         "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
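
The new "smart_one_diag" mode only falls back to a diagonal factor on the largest dimension when that dimension strictly dominates the runner-up, so square (or tied) tensors keep full triangular preconditioners. A standalone check of the rule as written above:

    import numpy as np

    def _max_idx(x):
        return len(x) - 1 - int(np.argmax(x[::-1]))  # prefer the trailing max

    def smart_one_diag(shape):
        dim_diag = [False] * len(shape)
        s = sorted(shape)
        if len(shape) >= 2 and s[-1] > s[-2]:
            dim_diag[_max_idx(shape)] = True
        return dim_diag

    assert smart_one_diag((4096, 1024)) == [True, False]   # lopsided: diagonal on dim 0
    assert smart_one_diag((1024, 1024)) == [False, False]  # square: stay full
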
@@ -1221,6 +1254,48 @@ def identity(x):
     return x
 
 
+@decorator_knowngood
+def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    _lerp32(p, ema32, 1 - weight_decay)
+
+
+def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    for p_, e_ in zip(p, ema32):
+        p32 = promote(p_)
+        p32 = p32 + (p32 - e_).sign() * weight_decay
+        copy_stochastic_(p_, p32)
+
+
+def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_sign_(grad: List[Tensor], graft: bool):
+    for g_ in grad:
+        gs = g_.sign()
+        if graft:
+            gs = _compilable_grafting(g_, gs)
+        copy_stochastic_(g_, gs)
+
+
+def sign_(grad: List[Tensor], graft: bool = True):
+    grad = list_guard(grad)
+    _compilable_sign_(grad, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
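
sign_ grafts by default: it keeps the sign direction but rescales it to the original gradient's norm, reusing _compilable_grafting defined earlier in this file. In isolation:

    import torch

    def graft(magnitude: torch.Tensor, direction: torch.Tensor):
        return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))

    g = torch.randn(10)
    update = graft(g, g.sign())  # sign direction, gradient-sized step
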
@@ -1300,7 +1375,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1312,8 +1390,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1322,7 +1400,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca
 
 
 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1332,8 +1412,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
heavyball-1.5.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.1
+Version: 1.5.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball-1.5.3.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
+heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
+heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
+heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
+heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.3.dist-info/RECORD,,
heavyball-1.5.1.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
-heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
-heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
-heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
-heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.1.dist-info/RECORD,,