heavyball 1.5.2-py3-none-any.whl → 1.5.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -163,6 +163,19 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)


+
+class LaPropOrtho(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.scale_by_laprop, C.orthogonalize_grad_to_param)
+
+
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -244,4 +257,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
            "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+           "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
heavyball/chainable.py CHANGED
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List

 import torch

@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
@@ -295,6 +311,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
@@ -312,7 +347,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:

     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
@@ -414,6 +449,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate


+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
@@ -439,8 +479,7 @@ def apply_to_idx(fn, idx):
     return _fn


-def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -450,10 +489,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)


+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False

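
The refactor splits the old chain into _inner_chain plus a thin chain wrapper so that create_branch can run several transform chains on clones of the same update and merge the results. A hedged construction sketch; merge_mean is an illustrative helper that is not part of the package, while create_branch, sign and scale_by_laprop are the chainable transforms shown above:

from heavyball import chainable as C

def merge_mean(outputs):
    # average the per-branch update lists tensor by tensor
    return [sum(tensors) / len(tensors) for tensors in zip(*outputs)]

branched = C.create_branch(
    branches=[[C.scale_by_laprop],  # branch 1: LaProp-scaled update
              [C.sign]],            # branch 2: sign of the update, grafted to its norm
    merge_fn=merge_mean)

The returned function has the same (state, group, update, grad, param) signature as the other chainable transforms, so it can sit anywhere in a chain.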
heavyball/utils.py CHANGED
@@ -317,6 +317,19 @@ def nesterov_momentum(state, grad, beta):
     return grad


+@decorator_knowngood
+def _compilable_nesterov_ema_(state, grad, beta):
+    ema32 = _lerp32(state, grad, beta)
+    stochastic_add_(grad, ema32, 1)
+
+
+def nesterov_ema(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_ema_(state, grad, beta)
+    return grad
+
+
 def _compilable_grafting(magnitude, direction):
     return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))

@@ -509,6 +522,19 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)


+@decorator_knowngood
+def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 * y32)
+
+
+def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    x, y = list_guard(x, y)
+    _compilable_stochastic_multiply_(x, y)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
@@ -783,7 +809,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis


 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-          eps: float):
+          eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -815,23 +841,23 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]

 @decorator_knowngood
 def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
-                        beta2: Tensor, step: Tensor):
+                        beta2: Tensor, step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

     gp32 = list(map(promote, grad))

-    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
     gp32 = torch._foreach_div(gp32, denom)
     gp32 = _lerp32(exp_avg, gp32, beta1)

     copy_stochastic_list_(grad, gp32)


-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0], eps)
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad

@@ -970,6 +996,10 @@ def get_soap_precond_schedule(precond_scheduler):
     return _inner


+def _max_idx(x: List[int]):
+    return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
+
+
 def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
@@ -992,17 +1022,20 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp

     scale = scale ** (1 / len(shape))

+    dim_diag = [False for _ in shape]
     if memory_save_mode is None:
-        dim_diag = [False for _ in shape]
+        pass
     elif memory_save_mode == "one_diag":
-        rev_sorted_dims = np.argsort(shape)[::-1]
-        dim_diag = [False for _ in shape]
-        dim_diag[rev_sorted_dims[0]] = True
+        dim_diag[_max_idx(shape)] = True
+    elif memory_save_mode == "smart_one_diag":
+        sorted_shape = sorted(shape)
+        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+            dim_diag[_max_idx(shape)] = True
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
         raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
-                         "[None, 'one_diag', 'all_diag']")
+                         "[None, 'one_diag', 'all_diag', 'smart_one_diag']")

     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
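
The new smart_one_diag mode only switches a dimension's preconditioner to diagonal when that dimension is strictly larger than every other one. A small standalone sketch of the selection logic, mirroring _max_idx and the branches above but written outside the package for illustration:

import numpy as np

def pick_diag_dims(shape, mode):
    # mirrors the memory_save_mode handling in init_Q_exprs
    dim_diag = [False for _ in shape]
    last_max = len(shape) - 1 - int(np.argmax(shape[::-1]))  # _max_idx: last occurrence of the max
    if mode == "one_diag":
        dim_diag[last_max] = True
    elif mode == "smart_one_diag":
        sorted_shape = sorted(shape)
        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
            dim_diag[last_max] = True
    elif mode == "all_diag":
        dim_diag = [True for _ in shape]
    return dim_diag

print(pick_diag_dims((1024, 1024, 50257), "smart_one_diag"))  # only the largest dim goes diagonal
print(pick_diag_dims((1024, 1024), "smart_one_diag"))         # tie: nothing goes diagonal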
@@ -1221,6 +1254,48 @@ def identity(x):
     return x


+@decorator_knowngood
+def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    _lerp32(p, ema32, 1 - weight_decay)
+
+
+def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    for p_, e_ in zip(p, ema32):
+        p32 = promote(p_)
+        p32 = p32 + (p32 - e_).sign() * weight_decay
+        copy_stochastic_(p_, p32)
+
+
+def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_sign_(grad: List[Tensor], graft: bool):
+    for g_ in grad:
+        gs = g_.sign()
+        if graft:
+            gs = _compilable_grafting(g_, gs)
+        copy_stochastic_(g_, gs)
+
+
+def sign_(grad: List[Tensor], graft: bool = True):
+    grad = list_guard(grad)
+    _compilable_sign_(grad, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
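
nesterov_ema keeps an EMA of the update and adds it back onto the raw update, which is what the "equivalent to Grokfast" comment in chainable.py refers to. A self-contained plain-PyTorch illustration of that idea; the exact interpolation convention used by heavyball's _lerp32 is an assumption here, so treat the beta handling as approximate:

import torch

def grokfast_like_step(grad: torch.Tensor, ema: torch.Tensor, beta: float = 0.9) -> torch.Tensor:
    # ema <- beta * ema + (1 - beta) * grad, then amplify the update with the EMA
    ema.lerp_(grad, 1 - beta)
    return grad + ema

g = torch.randn(8)
ema = torch.zeros_like(g)
update = grokfast_like_step(g, ema)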
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.2
+Version: 1.5.3
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
+heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
+heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
+heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
+heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.3.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
-heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
-heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
-heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
-heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.2.dist-info/RECORD,,