heavyball 1.3.0__tar.gz → 1.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.3.0 → heavyball-1.3.1}/PKG-INFO +1 -1
  2. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/chainable.py +57 -19
  3. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/utils.py +68 -35
  4. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.3.0 → heavyball-1.3.1}/setup.py +1 -1
  6. {heavyball-1.3.0 → heavyball-1.3.1}/LICENSE +0 -0
  7. {heavyball-1.3.0 → heavyball-1.3.1}/README.md +0 -0
  8. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/__init__.py +0 -0
  9. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.3.0 → heavyball-1.3.1}/setup.cfg +0 -0
  14. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_caution.py +0 -0
  18. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_channels_last.py +0 -0
  19. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_closure.py +0 -0
  20. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_ema.py +0 -0
  21. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_foreach.py +0 -0
  22. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_hook.py +0 -0
  23. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_mars.py +0 -0
  24. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_memory.py +0 -0
  25. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_merge.py +0 -0
  26. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_no_grad.py +0 -0
  27. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_psgd.py +0 -0
  28. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_soap.py +0 -0
  29. {heavyball-1.3.0 → heavyball-1.3.1}/test/test_stochastic_updates.py +0 -0
{heavyball-1.3.0 → heavyball-1.3.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.3.0
+Version: 1.3.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-1.3.0 → heavyball-1.3.1}/heavyball/chainable.py
@@ -307,7 +307,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond


-def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
    if prob is None:
        prob = utils.precond_update_prob_schedule()
    if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
@@ -322,6 +322,7 @@ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[cal
        else:
            utils.psgd_balance_Q(Q)

+    _update_psgd_cache(cached, Q_cache, Q_mat)

 def _update_psgd_cache(cached, Q_cache, q):
    if not cached:
@@ -357,7 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
    old = update
    update = update.to(memory_format=torch.contiguous_format)
    Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
    return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)


@@ -367,7 +368,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
                          prob: Optional[callable] = None):
    Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
    precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
    return precond


@@ -376,7 +377,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
    Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
    _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
    raise SkipUpdate

@@ -387,7 +388,7 @@ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_
                           prob: Optional[callable] = None):
    Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
    _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
    raise SkipUpdate


@@ -422,6 +423,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):

 class ChainOpt(utils.StatefulOptimizer):
    compile_step: bool = False
+    promote: bool = False

    def __init__(self, params, defaults, foreach: bool, *fns):
        super().__init__(params, defaults, foreach)
@@ -431,7 +433,7 @@ class ChainOpt(utils.StatefulOptimizer):
        if 'base_lr' not in group:
            group['base_lr'] = group['lr']

-        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
        if not vals:
            return
        p, g = zip(*vals)
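Note on the hunk above: `should_promote` was hard-coded to `False` and now follows the new `promote` attribute. A minimal plain-PyTorch sketch (not heavyball internals; heavyball additionally writes back with stochastic rounding) of why promoting to fp32 matters for low-precision parameters:

    import torch

    step = 1e-3  # well below bf16's ULP at 1.0, which is 2**-7 = 0.0078125

    naive = torch.ones(1, dtype=torch.bfloat16)
    for _ in range(100):
        naive += step                       # each add rounds back to 1.0

    promoted = torch.ones(1, dtype=torch.bfloat16).float()
    for _ in range(100):
        promoted += step                    # fp32 keeps accumulating the small steps
    promoted = promoted.to(torch.bfloat16)

    print(naive.item())                     # 1.0, the parameter never moved
    print(promoted.item())                  # ~1.1015625, the nearest bf16 to 1.1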
@@ -486,36 +488,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
                        scale_by_adam.get_fn(): update_by_adam,  #
                        scale_by_laprop.get_fn(): update_by_laprop,  #
                        scale_by_adopt.get_fn(): update_by_adopt}
+_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd,  #
+                            update_by_psgd.get_fn(): scale_by_psgd,  #
+                            update_by_adam.get_fn(): scale_by_adam,  #
+                            update_by_laprop.get_fn(): scale_by_laprop,  #
+                            update_by_adopt.get_fn(): scale_by_adopt}


 class BaseOpt(ChainOpt):
+    """
+    Base Optimizer
+
+    compile_step: bool = False
+        Whether to change some internals to try to make the optimizer compilable
+        This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+
+    promote: bool = False
+        Whether to promote the gradients to fp32 before applying the optimizer
+        Improves update quality for low-precision parameters, but increases costs
+        Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+
+    gradient_clipping: str_or_fn = None
+        The function to use for clipping the incoming gradients, before any other transformations.
+        This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+    update_clipping: str_or_fn = None
+        The function to use for clipping the outgoing updates before applying them, after all other transformations.
+        This will turn off
+        This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+    """
+
    gradient_clipping: str_or_fn = None
    update_clipping: str_or_fn = None
    palm: bool = False
    auto_fuse: bool = True

    def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, *fns, compile_step: bool = use_default):
+                 palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+        if not fns:
+            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+        args, kwargs = None, None
+        fn = fns[-1]
+        if isinstance(fn, functools.partial):
+            fn, args, kwargs = fn.func, fn.args, fn.keywords
+        if isinstance(fn, FunctionTransform):
+            fn = fn.get_fn()
+
        if default(update_clipping, self.update_clipping) is None:
-            if fns and self.auto_fuse:
-                args, kwargs = None, None
-                fn = fns[-1]
-                if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fn.func, fn.args, fn.keywords
-                if isinstance(fn, FunctionTransform):
-                    fn = fn.get_fn()
+            if self.auto_fuse:
                if fn in _scale_to_update_map:
                    fn = _scale_to_update_map[fn]
                    if args is not None:
                        fn = functools.partial(fn, *args, **kwargs)
                    fns = tuple(fns)[:-1] + (fn,)
-        else:
-            if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
-                raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
-
-        fns = tuple(fns)
+        elif fn in _scale_to_update_map_inv:
+            if not self.auto_fuse:
+                raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+                                 "Manually select scale_by_* functions or set auto_fuse=True.")
+            fn = _scale_to_update_map_inv[fn]
+            if args is not None:
+                fn = functools.partial(fn, *args, **kwargs)
+            fns = tuple(fns)[:-1] + (fn,)

        self.compile_step = default(compile_step, self.compile_step)
+        self.promote = default(promote, self.promote)
        if default(palm, self.palm):
            fns = (palm_beta2,) + fns
        if default(gradient_clipping, self.gradient_clipping) is not None:
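Note on the hunk above: `_scale_to_update_map_inv` mirrors the existing `_scale_to_update_map`. Without `update_clipping`, a trailing `scale_by_*` transform is fused into its `update_by_*` form; with `update_clipping` set, a trailing `update_by_*` transform is mapped back to `scale_by_*` so the clipping step still sees the final update. A hypothetical, self-contained sketch of that dispatch (stand-in names, not heavyball's API):

    # Stand-ins: a scale_by_* transform returns the update, while its fused
    # update_by_* counterpart applies the update directly, leaving nothing to clip.
    def scale_by_adam_sketch(update):
        return update

    def update_by_adam_sketch(update):
        raise RuntimeError("fused: the update is applied in place")

    SCALE_TO_UPDATE = {scale_by_adam_sketch: update_by_adam_sketch}
    UPDATE_TO_SCALE = {v: k for k, v in SCALE_TO_UPDATE.items()}

    def resolve_last_fn(fn, update_clipping):
        if update_clipping is None:
            return SCALE_TO_UPDATE.get(fn, fn)  # fuse: nothing downstream needs the update
        return UPDATE_TO_SCALE.get(fn, fn)      # unfuse: clipping must run on the update

    assert resolve_last_fn(scale_by_adam_sketch, None) is update_by_adam_sketch
    assert resolve_last_fn(update_by_adam_sketch, "l2_clip_") is scale_by_adam_sketch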
{heavyball-1.3.0 → heavyball-1.3.1}/heavyball/utils.py
@@ -1101,65 +1101,98 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):


 @decorator_knowngood
-def _compilable_l2_clip_(x):
+def _compilable_l2_clip_(x, clip_at):
    ref = x
    x = list(map(promote, x))
    norm = torch._foreach_norm(x)
-    torch._foreach_maximum_(norm, 1e-8)
+    torch._foreach_maximum_(norm, clip_at)
    out = torch._foreach_div(x, norm)
    return stochastic_round_list_(ref, out)


-def l2_clip_(x):
+def l2_normalization_(x, clip_at: float = 1e-8):
    x = list_guard(x)
-    return _compilable_l2_clip_(x)
+    return _compilable_l2_clip_(x, clip_at)
+
+
+def l2_clip_(x, clip_at: float = 1.):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x, clip_at)


 @decorator_knowngood
-def _compilable_rmsnorm_clip_(x):
+def _compilable_rmsnorm_clip_(x, clip_at):
    x = list(map(promote, x))
    norm = torch._foreach_norm(x)
    norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm, 1e-6)
+    torch._foreach_maximum_(norm, clip_at)
    return torch._foreach_div(x, norm)


-def rmsnorm_clip_(x):
+def rmsnorm_clip_(x, clip_at: float = 1.0):
    x = list_guard(x)
-    return _compilable_rmsnorm_clip_(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)


-def mu_law_compress(x, mu=127.0):
+def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
+
+
+@decorator_knowngood
+def _compilable_mu_law_compress_(x, mu):
+    """
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
    """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py

+    for x_ in x:
+        xa = promote(x_.abs()) * mu
+        xa = xa.log1p()
+        xa = xa / math.log1p(mu)
+        xa = xa.copysign(x_)
+        copy_stochastic_(x_, xa)
+
+
+def mu_law_compress(x, mu=127.0):
+    """
    μ-law compression
    Args:
        x: Input tensor
        mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
    """
-    xa = torch._foreach_abs_(x)
-    torch._foreach_mul_(xa, mu)
-    torch._foreach_log1p_(xa)
-    torch._foreach_div_(xa, math.log1p(mu))
-    return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+    x = list_guard(x)
+    mu = scalar_guard(mu, x[0])
+    _compilable_mu_law_compress(x, mu)
+    return x


-def a_law_compress(x, A=87.6):
+@decorator_knowngood
+def _compilable_a_law_compress_(x, A):
+    """
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
    """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    for x_ in x:
+        xa = promote(x_.abs()) * A
+        xa = torch.where(xa < 1, xa, 1 + xa.log())
+        xa = xa.copysign(x_)
+        xa = xa * (1 / (1 + math.log(A)))
+        copy_stochastic_(x_, xa)

+
+def a_law_compress(x, A=87.6):
+    """
    A-law compression
    Args:
        x: Input tensor
        A: Compression parameter (default 87.6 - European PCM standard)
+    :param x:
+    :param A:
+    :return:
    """
-    xa = torch._foreach_abs(x)
-    torch._foreach_mul_(xa, A)
-    [torch.where(x_ < 1, x_, 1 + torch.log_(x_), out=x_) for x_ in xa]
-    [xa_.copysign(x_) for x_, xa_ in zip(x, xa)]
-    torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
-    return xa
+    x = list_guard(x)
+    A = scalar_guard(A, x[0])
+    _compilable_a_law_compress(x, A)
+    return x


 def identity(x):
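Note on the hunk above: each norm helper is split into an explicit pair: `l2_normalization_` and `rmsnorm_normalize_` keep the old tiny floors (1e-8, 1e-6), while `l2_clip_` and `rmsnorm_clip_` now default `clip_at` to 1.0; the μ-law and A-law compressors also move from `torch._foreach_*` calls to per-tensor loops. A plain-tensor sketch (single tensors, no foreach or stochastic rounding) of both behaviors:

    import math
    import torch

    # Dividing by max(norm, clip_at) covers both entry points: a tiny clip_at always
    # rescales to unit norm (normalization); clip_at=1.0 only shrinks vectors whose
    # norm exceeds 1 (clipping).
    def l2_scale(x: torch.Tensor, clip_at: float) -> torch.Tensor:
        return x / x.norm().clamp_min(clip_at)

    g = torch.full((4,), 0.1)                  # ||g|| = 0.2
    print(l2_scale(g, 1e-8).norm().item())     # ~1.0  (l2_normalization_-like)
    print(l2_scale(g, 1.0).norm().item())      # 0.2   (l2_clip_-like: already inside the unit ball)

    # Per-tensor body of the mu-law compressor: log-compress the magnitude,
    # normalize by log1p(mu), restore the sign.
    def mu_law(x: torch.Tensor, mu: float = 127.0) -> torch.Tensor:
        return (x.abs().mul(mu).log1p() / math.log1p(mu)).copysign(x)

    print(mu_law(torch.tensor([-4.0, -0.1, 0.0, 0.1, 4.0])))
    # roughly [-1.28, -0.54, 0.00, 0.54, 1.28]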
@@ -1167,24 +1200,24 @@ def identity(x):


 @decorator_knowngood
-def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+def _compilable_trust_region_clip_(grad, lerp, scale):
    # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
-    g32 = list(map(promote, grad))
-    [g.mul_(1 / scale) for g in g32]
-    tanh = torch._foreach_tanh(g32)
-    torch._foreach_abs_(g32)
-    torch._foreach_log1p_(g32)
-    [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
-
-    torch._foreach_maximum_(g32, -2)
-    torch._foreach_minimum_(g32, 2)
-    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+    for x_ in grad:
+        x = promote(x_)
+        x = x / scale
+        tanh = x.tanh()
+        x = x.abs().log1p()
+        x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+        x = x * scale
+        x = x.clamp(min=-2, max=2)
+        copy_stochastic_(x_, x)


 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
    grad = list_guard(grad)
    lerp, scale = scalar_guard(lerp, scale, grad[0])
-    return _compilable_trust_region_clip_(grad, lerp, scale)
+    _compilable_trust_region_clip_(grad, lerp, scale)
+    return grad


 @decorator
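Note on the hunk above: the trust-region clip is rewritten from `torch._foreach_*` calls into a per-tensor loop that writes results back through `copy_stochastic_` and returns the mutated input list. A single-tensor sketch of the transform with the defaults `lerp=0.9`, `scale=1.5` (plain PyTorch, no stochastic rounding):

    import torch

    def trust_region(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
        z = x / scale
        t = z.tanh()
        soft = z.abs().log1p().copysign(t)   # sign taken from tanh, as in the new loop
        return (soft * (1 - lerp) + t * lerp).mul(scale).clamp(min=-2, max=2)

    g = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
    print(trust_region(g))   # large entries compress to about +/-1.66, small ones shrink only mildly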
{heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.3.0
+Version: 1.3.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-1.3.0 → heavyball-1.3.1}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
    name='heavyball',
    license='BSD',
    description='Efficient optimizers',
-    version='1.3.0',
+    version='1.3.1',
    long_description=README,
    url='https://github.com/clashluke/heavyball',
    packages=setuptools.find_packages(),