heavyball 1.2.3.tar.gz → 1.3.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.2.3 → heavyball-1.3.1}/PKG-INFO +1 -1
  2. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball/chainable.py +64 -25
  3. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball/utils.py +74 -42
  4. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.2.3 → heavyball-1.3.1}/setup.py +1 -1
  6. {heavyball-1.2.3 → heavyball-1.3.1}/LICENSE +0 -0
  7. {heavyball-1.2.3 → heavyball-1.3.1}/README.md +0 -0
  8. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball/__init__.py +0 -0
  9. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.2.3 → heavyball-1.3.1}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.2.3 → heavyball-1.3.1}/setup.cfg +0 -0
  14. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_caution.py +0 -0
  18. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_channels_last.py +0 -0
  19. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_closure.py +0 -0
  20. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_ema.py +0 -0
  21. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_foreach.py +0 -0
  22. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_hook.py +0 -0
  23. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_mars.py +0 -0
  24. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_memory.py +0 -0
  25. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_merge.py +0 -0
  26. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_no_grad.py +0 -0
  27. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_psgd.py +0 -0
  28. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_soap.py +0 -0
  29. {heavyball-1.2.3 → heavyball-1.3.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.2.3
+Version: 1.3.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -307,7 +307,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond
 
 
-def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
@@ -322,6 +322,7 @@ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[cal
     else:
         utils.psgd_balance_Q(Q)
 
+    _update_psgd_cache(cached, Q_cache, Q_mat)
 
 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
@@ -357,7 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
 
 
@@ -367,7 +368,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return precond
 
 
@@ -376,7 +377,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate
 
@@ -387,7 +388,7 @@ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     raise SkipUpdate
 
 
@@ -422,6 +423,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 class ChainOpt(utils.StatefulOptimizer):
     compile_step: bool = False
+    promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
@@ -431,21 +433,22 @@ class ChainOpt(utils.StatefulOptimizer):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
 
-        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)
 
         for param in p:
             state = self.state_(param)
-            if 'step' not in state:
-                if self.compile_step:
-                    step = utils.scalar_guard(0, param)
-                state['step'] = step
-            step = state['step'].add_(1)
+            if 'step' in state:
+                step = state['step']
+            elif self.compile_step:
+                step = utils.scalar_guard(0, param)
+            else:
+                step = 0
             break
 
-        group['step'] = step
+        group['step'] = state['step'] = step = step + 1
 
         if group['warmup_steps'] and step < group['warmup_steps']:
             group['lr'] = group['base_lr'] * step / group['warmup_steps']
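The step bookkeeping above keeps the counter in the per-parameter state and re-stores it on every call, either as a plain int or as a 0-d tensor when `compile_step` is set. A minimal standalone sketch of that branch, assuming `utils.scalar_guard(0, param)` roughly amounts to a 0-d tensor on the parameter's device:

import torch

def next_step(state: dict, param: torch.Tensor, compile_step: bool):
    # reuse the stored counter if present, otherwise start at 0
    if 'step' in state:
        step = state['step']
    elif compile_step:
        step = torch.zeros((), device=param.device, dtype=torch.int64)  # stand-in for utils.scalar_guard(0, param)
    else:
        step = 0
    state['step'] = step = step + 1  # re-stored each call, like group['step'] = state['step'] = step + 1
    return step

state, p = {}, torch.zeros(3)
assert next_step(state, p, compile_step=False) == 1
assert next_step(state, p, compile_step=False) == 2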
@@ -485,36 +488,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
                         scale_by_adam.get_fn(): update_by_adam,  #
                         scale_by_laprop.get_fn(): update_by_laprop,  #
                         scale_by_adopt.get_fn(): update_by_adopt}
+_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd,  #
+                            update_by_psgd.get_fn(): scale_by_psgd,  #
+                            update_by_adam.get_fn(): scale_by_adam,  #
+                            update_by_laprop.get_fn(): scale_by_laprop,  #
+                            update_by_adopt.get_fn(): scale_by_adopt}
 
 
 class BaseOpt(ChainOpt):
+    """
+    Base Optimizer
+
+    compile_step: bool = False
+        Whether to change some internals to try to make the optimizer compilable.
+        This does not compile the step by itself and breaks some optimizers loudly (e.g., SOAP).
+
+    promote: bool = False
+        Whether to promote the gradients to fp32 before applying the optimizer.
+        Improves update quality for low-precision parameters, but increases costs.
+        Compiling the optimizer step would reduce memory and compute; alternatively, `foreach=False` decreases memory at the cost of runtime.
+
+    gradient_clipping: str_or_fn = None
+        The function to use for clipping the incoming gradients, before any other transformations.
+        This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+    update_clipping: str_or_fn = None
+        The function to use for clipping the outgoing updates before applying them, after all other transformations.
+        This will turn off fused updates: `update_by_*` functions are mapped back to their `scale_by_*` counterparts.
+        This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+    """
+
     gradient_clipping: str_or_fn = None
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
 
     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, compile_step: bool = use_default, *fns):
+                 palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+        if not fns:
+            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+        args, kwargs = None, None
+        fn = fns[-1]
+        if isinstance(fn, functools.partial):
+            fn, args, kwargs = fn.func, fn.args, fn.keywords
+        if isinstance(fn, FunctionTransform):
+            fn = fn.get_fn()
+
         if default(update_clipping, self.update_clipping) is None:
-            if fns and self.auto_fuse:
-                args, kwargs = None, None
-                fn = fns[-1]
-                if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fn.func, fn.args, fn.keywords
-                if isinstance(fn, FunctionTransform):
-                    fn = fn.get_fn()
+            if self.auto_fuse:
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
                         fn = functools.partial(fn, *args, **kwargs)
                     fns = tuple(fns)[:-1] + (fn,)
-        else:
-            if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
-                raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
-
-        fns = tuple(fns)
+        elif fn in _scale_to_update_map_inv:
+            if not self.auto_fuse:
+                raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+                                 "Manually select scale_by_* functions or set auto_fuse=True.")
+            fn = _scale_to_update_map_inv[fn]
+            if args is not None:
+                fn = functools.partial(fn, *args, **kwargs)
+            fns = tuple(fns)[:-1] + (fn,)
 
         self.compile_step = default(compile_step, self.compile_step)
+        self.promote = default(promote, self.promote)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
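A hypothetical usage sketch of the knobs documented in the new BaseOpt docstring. The concrete class name (ForeachAdamW) and its keyword pass-through are assumptions based on heavyball's public optimizers; only BaseOpt's own signature appears in this diff.

import torch
import heavyball
from heavyball import utils

heavyball.ForeachAdamW.promote = True  # class-level default; BaseOpt falls back to self.promote

model = torch.nn.Linear(16, 16).bfloat16()
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3,
                             gradient_clipping=utils.l2_clip_,           # runs before all other transforms
                             update_clipping=utils.trust_region_clip_)   # runs last; update_by_* maps back to scale_by_*

model(torch.randn(4, 16, dtype=torch.bfloat16)).sum().backward()
opt.step()
opt.zero_grad()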
@@ -88,7 +88,7 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
     except ZeroDivisionError:
         ckp1 = 0
 
-    update, parameters, z = list_guard(update, parameters, z)
+    update, parameters, z, grad = list_guard(update, parameters, z, grad)
     lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
     return weight_sum
@@ -912,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
 @decorator_knowngood
 def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
                         g: List[Optional[Tensor]]):
-    u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
-    p32, u32 = [list(map(promote, x)) for x in [p, u]]
-
-    for p32_, u32_, g_, p_ in zip(p32, u32, g, p):  # lr is data-dependent -> can't compile a foreach
+    for u_, g_, p_ in zip(u, g, p):  # lr is data-dependent -> can't compile a foreach
+        u_ = promote(u_.view_as(p_))
+        p32_ = promote(p_)
         if caution:
-            u32_ = _compilable_cautioning(promote(g_), u32_)
-        p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
+            u_ = _compilable_cautioning(promote(g_), u_)
+        p32_ = p32_ * (1 - decay * lr) + u_ * -lr
         copy_stochastic_(p_, p32_)
 
 
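Spelled out, the per-tensor loop above applies decoupled weight decay plus the scaled update, with an optional "cautious" mask. A rough standalone sketch, assuming `_compilable_cautioning` zeroes update components whose sign disagrees with the gradient (the in-package helper may also rescale the survivors) and using a plain cast where heavyball rounds stochastically:

import torch

def apply_update(p: torch.Tensor, u: torch.Tensor, g: torch.Tensor,
                 lr: float, decay: float, caution: bool) -> None:
    p32, u32 = p.float(), u.view_as(p).float()
    if caution:
        u32 = u32 * ((u32 * g.float()) > 0)   # assumed cautioning: drop sign-mismatched entries
    p32 = p32 * (1 - decay * lr) - lr * u32   # decoupled decay, then the scaled step
    p.copy_(p32)                              # copy_stochastic_ would round stochastically instead

p = torch.randn(4, dtype=torch.bfloat16)
apply_update(p, torch.randn(4), torch.randn(4), lr=1e-3, decay=1e-2, caution=True)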
@@ -1102,65 +1101,98 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
 
 
 @decorator_knowngood
-def _compilable_l2_clip_(x):
+def _compilable_l2_clip_(x, clip_at):
     ref = x
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-    torch._foreach_maximum_(norm, 1e-8)
+    torch._foreach_maximum_(norm, clip_at)
     out = torch._foreach_div(x, norm)
     return stochastic_round_list_(ref, out)
 
 
-def l2_clip_(x):
+def l2_normalization_(x, clip_at: float = 1e-8):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x, clip_at)
+
+
+def l2_clip_(x, clip_at: float = 1.):
     x = list_guard(x)
-    return _compilable_l2_clip_(x)
+    return _compilable_l2_clip_(x, clip_at)
 
 
 @decorator_knowngood
-def _compilable_rmsnorm_clip_(x):
+def _compilable_rmsnorm_clip_(x, clip_at):
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm, 1e-6)
+    torch._foreach_maximum_(norm, clip_at)
     return torch._foreach_div(x, norm)
 
 
-def rmsnorm_clip_(x):
+def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-    return _compilable_rmsnorm_clip_(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
 
 
-def mu_law_compress(x, mu=127.0):
+def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
+
+
+@decorator_knowngood
+def _compilable_mu_law_compress_(x, mu):
     """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    """
+
+    for x_ in x:
+        xa = promote(x_.abs()) * mu
+        xa = xa.log1p()
+        xa = xa / math.log1p(mu)
+        xa = xa.copysign(x_)
+        copy_stochastic_(x_, xa)
 
+
+def mu_law_compress(x, mu=127.0):
+    """
     μ-law compression
     Args:
         x: Input tensor
         mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
     """
-    xa = torch._foreach_abs_(x)
-    torch._foreach_mul_(xa, mu)
-    torch._foreach_log1p_(xa)
-    torch._foreach_div_(xa, math.log1p(mu))
-    return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+    x = list_guard(x)
+    mu = scalar_guard(mu, x[0])
+    _compilable_mu_law_compress_(x, mu)
+    return x
 
 
-def a_law_compress(x, A=87.6):
+@decorator_knowngood
+def _compilable_a_law_compress_(x, A):
     """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    """
+    for x_ in x:
+        xa = promote(x_.abs()) * A
+        xa = torch.where(xa < 1, xa, 1 + xa.log())
+        xa = xa.copysign(x_)
+        xa = xa * (1 / (1 + math.log(A)))
+        copy_stochastic_(x_, xa)
+
 
+def a_law_compress(x, A=87.6):
+    """
     A-law compression
     Args:
         x: Input tensor
         A: Compression parameter (default 87.6 - European PCM standard)
+    :param x:
+    :param A:
+    :return:
     """
-    xa = torch._foreach_abs(x)
-    torch._foreach_mul_(xa, A)
-    [torch.where(x_ < 1, x_, 1 + torch.log_(x_), out=x_) for x_ in xa]
-    [xa_.copysign(x_) for x_, xa_ in zip(x, xa)]
-    torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
-    return xa
+    x = list_guard(x)
+    A = scalar_guard(A, x[0])
+    _compilable_a_law_compress_(x, A)
+    return x
 
 
 def identity(x):
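The compilable μ-law path above applies, element-wise, y = copysign(log1p(mu * |x|) / log1p(mu), x). A quick reference implementation for sanity-checking (plain torch, no stochastic rounding):

import math
import torch

def mu_law_reference(x: torch.Tensor, mu: float = 127.0) -> torch.Tensor:
    # same mapping as _compilable_mu_law_compress_, without the in-place stochastic copy
    return (torch.log1p(mu * x.abs()) / math.log1p(mu)).copysign(x)

print(mu_law_reference(torch.tensor([-2.0, -0.1, 0.0, 0.1, 2.0])))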
@@ -1168,24 +1200,24 @@ def identity(x):
 
 
 @decorator_knowngood
-def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
-    g32 = list(map(promote, grad))
-    [g.mul_(1 / scale) for g in g32]
-    tanh = torch._foreach_tanh(g32)
-    torch._foreach_abs_(g32)
-    torch._foreach_log1p_(g32)
-    [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
-
-    torch._foreach_maximum_(g32, -2)
-    torch._foreach_minimum_(g32, 2)
-    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+    for x_ in grad:
+        x = promote(x_)
+        x = x / scale
+        tanh = x.tanh()
+        x = x.abs().log1p()
+        x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+        x = x * scale
+        x = x.clamp(min=-2, max=2)
+        copy_stochastic_(x_, x)
 
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
     lerp, scale = scalar_guard(lerp, scale, grad[0])
-    return _compilable_trust_region_clip_(grad, lerp, scale)
+    _compilable_trust_region_clip_(grad, lerp, scale)
+    return grad
 
 
 @decorator
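The trust-region curve above blends a log-compressed magnitude with tanh and clamps the result; with the defaults lerp=0.9, scale=1.5 it corresponds to the commented formula (the comment omits the scale factor). An element-wise reference mirroring the new loop:

import torch

def trust_region_reference(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
    xs = x / scale
    t = xs.tanh()
    blended = xs.abs().log1p().copysign(t) * (1 - lerp) + t * lerp
    return (blended * scale).clamp(min=-2, max=2)

print(trust_region_reference(torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])))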
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.2.3
+Version: 1.3.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='1.2.3',
+    version='1.3.1',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),