heavyball 1.3.0.tar.gz → 1.4.0.tar.gz

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.3.0 → heavyball-1.4.0}/PKG-INFO +1 -1
  2. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/chainable.py +91 -37
  3. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/utils.py +85 -52
  4. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.3.0 → heavyball-1.4.0}/setup.py +1 -1
  6. {heavyball-1.3.0 → heavyball-1.4.0}/LICENSE +0 -0
  7. {heavyball-1.3.0 → heavyball-1.4.0}/README.md +0 -0
  8. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/__init__.py +0 -0
  9. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.3.0 → heavyball-1.4.0}/setup.cfg +0 -0
  14. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_caution.py +0 -0
  18. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_channels_last.py +0 -0
  19. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_closure.py +0 -0
  20. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_ema.py +0 -0
  21. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_foreach.py +0 -0
  22. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_hook.py +0 -0
  23. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_mars.py +0 -0
  24. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_memory.py +0 -0
  25. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_merge.py +0 -0
  26. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_no_grad.py +0 -0
  27. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_psgd.py +0 -0
  28. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_soap.py +0 -0
  29. {heavyball-1.3.0 → heavyball-1.4.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.3.0
+ Version: 1.4.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -1,6 +1,5 @@
  import functools
  import random
- import warnings
  from typing import Optional, Union, Literal

  import torch
@@ -85,10 +84,11 @@ class CopyGuard(FunctionTransform):


  class GeneralGuard(FunctionTransform): # We can't guard against reuse in the general case
- def __init__(self, fn, names, init_fn):
+ def __init__(self, fn, names, init_fn, skip_first: bool = True):
  super().__init__(fn)
  self.names = names
  self.init_fn = init_fn
+ self.skip_first = skip_first

  def __call__(self, state, group, update, grad, param, *args, **kwargs):
  vars = []
@@ -97,7 +97,7 @@ class GeneralGuard(FunctionTransform): # We can't guard against reuse in the ge
  st = state(p)
  skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
  vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
- if skip_update:
+ if skip_update and self.skip_first:
  raise SkipUpdate
  return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)

@@ -109,8 +109,17 @@ class NoState(FunctionTransform):

  class NoStateNoForeach(FunctionTransform):
  def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ updates = []
+ skip_update = False
  for a in zip(update, grad, param, *args):
- return self.fn(group, *a, **kwargs)
+ try:
+ updates.append(self.fn(group, *a, **kwargs))
+ except SkipUpdate:
+ skip_update = True
+ pass
+ if skip_update:
+ raise SkipUpdate
+ return updates


  def zero_guard(*names):
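The rewritten `NoStateNoForeach.__call__` no longer returns after the first tensor: it collects one result per tensor and re-raises `SkipUpdate` only once every tensor has been processed. A minimal standalone sketch of that control flow (the `SkipUpdate` class and `apply_per_tensor` helper below are illustrative stand-ins, not heavyball's own code):

```python
# Sketch only: `SkipUpdate` and `apply_per_tensor` are hypothetical stand-ins
# for heavyball's internals, shown to illustrate the new control flow.
class SkipUpdate(Exception):
    """Signals that a transform consumed the update and nothing should be applied."""


def apply_per_tensor(fn, updates, grads, params):
    results, skipped = [], False
    for u, g, p in zip(updates, grads, params):
        try:
            results.append(fn(u, g, p))  # collect every tensor's result
        except SkipUpdate:
            skipped = True  # remember the skip, but keep processing the rest
    if skipped:
        raise SkipUpdate  # propagate only after the whole loop has finished
    return results
```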
@@ -118,11 +127,11 @@ def zero_guard(*names):


  def copy_guard(index, *names):
- return functools.partial(CopyGuard, index=index, names=names)
+ return functools.partial(CopyGuard, index=index, names=names,)


- def general_guard(*names, init_fn):
- return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
+ def general_guard(*names, init_fn, skip_first: bool = True):
+ return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)


  def no_state(fn):
@@ -307,21 +316,22 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
  return precond


- def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
  if prob is None:
  prob = utils.precond_update_prob_schedule()
  if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
- return
+ return Q_mat

- Q = [utils.promote(q_) for q_ in Q]
  utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])

- if grad.dim() > 1 and precond_schedule(group, balance_probability, "balance_prob"):
+ if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
  if group['store_triu_as_line']:
  utils.psgd_balance_Q([q_ for _, q_ in Q])
  else:
  utils.psgd_balance_Q(Q)

+ return _update_psgd_cache(cached, Q_cache, Q_mat)
+

  def _update_psgd_cache(cached, Q_cache, q):
  if not cached:
@@ -350,44 +360,47 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
  group['caution'], *Q_mat)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
- old = update
  update = update.to(memory_format=torch.contiguous_format)
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+ Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+ update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
  return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
- _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+ Q_mat, Q, exprs, prob)
  return precond


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+ Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+ update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
  _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
  raise SkipUpdate


- @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+ @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
- _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+ Q_mat, Q, exprs, prob)
  raise SkipUpdate

@@ -421,7 +434,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):


  class ChainOpt(utils.StatefulOptimizer):
- compile_step: bool = False
+ promote: bool = False

  def __init__(self, params, defaults, foreach: bool, *fns):
  super().__init__(params, defaults, foreach)
@@ -430,8 +443,12 @@ class ChainOpt(utils.StatefulOptimizer):
  def _step(self, group):
  if 'base_lr' not in group:
  group['base_lr'] = group['lr']
+ if 'prev_lr' in group and group['prev_lr'] != group['lr']:
+ utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
+ f'only supported with foreach=True (currently foreach={group["foreach"]}).')
+ group['base_lr'] = group['lr']

- vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+ vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
  if not vals:
  return
  p, g = zip(*vals)
@@ -449,9 +466,10 @@ class ChainOpt(utils.StatefulOptimizer):
  group['step'] = state['step'] = step = step + 1

  if group['warmup_steps'] and step < group['warmup_steps']:
- group['lr'] = group['base_lr'] * step / group['warmup_steps']
+ group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
+
  else:
- group['lr'] = group['base_lr']
+ group['prev_lr'] = group['lr'] = group['base_lr']

  if not group['foreach'] or len(p) == 1:
  for param, grad in zip(p, g):
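The `_step` changes above track `prev_lr` so that a user-supplied learning-rate change between steps is detected and folded into `base_lr` before the warmup ramp is applied. A condensed sketch of that bookkeeping, with a plain dict standing in for the param group (the `apply_warmup` helper is hypothetical):

```python
# Sketch of the prev_lr/base_lr bookkeeping; `apply_warmup` is a hypothetical
# helper operating on a plain dict, not heavyball's ChainOpt._step itself.
def apply_warmup(group: dict, step: int) -> float:
    if 'base_lr' not in group:
        group['base_lr'] = group['lr']
    # An external lr change since the previous step becomes the new base lr.
    if 'prev_lr' in group and group['prev_lr'] != group['lr']:
        group['base_lr'] = group['lr']
    warmup = group.get('warmup_steps', 0)
    if warmup and step < warmup:
        lr = group['base_lr'] * step / warmup  # linear ramp up to base_lr
    else:
        lr = group['base_lr']
    group['prev_lr'] = group['lr'] = lr
    return lr
```

For instance, `apply_warmup({'lr': 1e-3, 'warmup_steps': 100}, step=10)` returns 1e-4, one tenth of the base rate.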
@@ -486,36 +504,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
  scale_by_adam.get_fn(): update_by_adam, #
  scale_by_laprop.get_fn(): update_by_laprop, #
  scale_by_adopt.get_fn(): update_by_adopt}
+ _scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
+ update_by_psgd.get_fn(): scale_by_psgd, #
+ update_by_adam.get_fn(): scale_by_adam, #
+ update_by_laprop.get_fn(): scale_by_laprop, #
+ update_by_adopt.get_fn(): scale_by_adopt}


  class BaseOpt(ChainOpt):
+ """
+ Base Optimizer
+
+ compile_step: bool = False
+ Whether to change some internals to try to make the optimizer compilable
+ This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+
+ promote: bool = False
+ Whether to promote the gradients to fp32 before applying the optimizer
+ Improves update quality for low-precision parameters, but increases costs
+ Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+
+ gradient_clipping: str_or_fn = None
+ The function to use for clipping the incoming gradients, before any other transformations.
+ This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+ update_clipping: str_or_fn = None
+ The function to use for clipping the outgoing updates before applying them, after all other transformations.
+ This will turn off
+ This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+ """
+
  gradient_clipping: str_or_fn = None
  update_clipping: str_or_fn = None
  palm: bool = False
  auto_fuse: bool = True

  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = use_default, *fns, compile_step: bool = use_default):
+ palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+ if not fns:
+ raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+ args, kwargs = None, None
+ fn = fns[-1]
+ if isinstance(fn, functools.partial):
+ fn, args, kwargs = fn.func, fn.args, fn.keywords
+ if isinstance(fn, FunctionTransform):
+ fn = fn.get_fn()
+
  if default(update_clipping, self.update_clipping) is None:
- if fns and self.auto_fuse:
- args, kwargs = None, None
- fn = fns[-1]
- if isinstance(fn, functools.partial):
- fn, args, kwargs = fn.func, fn.args, fn.keywords
- if isinstance(fn, FunctionTransform):
- fn = fn.get_fn()
+ if self.auto_fuse:
  if fn in _scale_to_update_map:
  fn = _scale_to_update_map[fn]
  if args is not None:
  fn = functools.partial(fn, *args, **kwargs)
  fns = tuple(fns)[:-1] + (fn,)
- else:
- if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
- raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
-
- fns = tuple(fns)
+ elif fn in _scale_to_update_map_inv:
+ if not self.auto_fuse:
+ raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+ "Manually select scale_by_* functions or set auto_fuse=True.")
+ fn = _scale_to_update_map_inv[fn]
+ if args is not None:
+ fn = functools.partial(fn, *args, **kwargs)
+ fns = tuple(fns)[:-1] + (fn,)

  self.compile_step = default(compile_step, self.compile_step)
+ self.promote = default(promote, self.promote)
  if default(palm, self.palm):
  fns = (palm_beta2,) + fns
  if default(gradient_clipping, self.gradient_clipping) is not None:
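The constructor now inspects the last element of the chain once, up front, and swaps it in either direction: with no `update_clipping`, a trailing `scale_by_*` transform is fused into its `update_by_*` counterpart; with `update_clipping` set, a trailing `update_by_*` transform is mapped back through the new `_scale_to_update_map_inv` so the clip can act on the update before it is applied. A simplified, self-contained sketch of that remapping (the maps and function names below are stand-ins, not heavyball's objects, which are keyed on `FunctionTransform.get_fn()`):

```python
# Illustrative stand-ins for the scale_by_*/update_by_* pairs.
def scale_by_adam(): ...
def update_by_adam(): ...

SCALE_TO_UPDATE = {scale_by_adam: update_by_adam}
UPDATE_TO_SCALE = {v: k for k, v in SCALE_TO_UPDATE.items()}


def remap_tail(fns: tuple, update_clipping=None, auto_fuse: bool = True) -> tuple:
    last = fns[-1]
    if update_clipping is None:
        if auto_fuse and last in SCALE_TO_UPDATE:
            last = SCALE_TO_UPDATE[last]  # fuse: write the update in one pass
    elif last in UPDATE_TO_SCALE:
        if not auto_fuse:
            raise ValueError("update_clipping requires a scale_by_* tail")
        last = UPDATE_TO_SCALE[last]  # unfuse so clipping can see the update
    return fns[:-1] + (last,)
```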
@@ -193,12 +193,12 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
  return grad


- # TODO: This lerp was fucked - check other lerps
  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
- s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
- copy_stochastic_list_(state, s32)
- copy_stochastic_list_(grad, s32)
+ for s, g in zip(state, grad):
+ lerped = s.lerp(g, 1 - beta)
+ copy_stochastic_(s, lerped)
+ copy_stochastic_(g, lerped)


  def scale_by_exp_avg_(state, grad, beta):
@@ -592,6 +592,7 @@ def project(grad, Q, back: bool):

  class StatefulOptimizer(torch.optim.Optimizer):
  ema_decay: float = 0.001
+ compile_step: bool = False

  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
@@ -637,6 +638,10 @@ class StatefulOptimizer(torch.optim.Optimizer):

  p.grad = None

+ if self.compile_step:
+ yield p, grad
+ continue
+
  p_views = merge_group(group, p)
  if grad is not None:
  grad = merge_group(group, grad)
@@ -1030,7 +1035,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
  V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
  eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
  eps *= G.norm() / G.numel()
- G += V * eps
+ G = G + V * eps
  md = min_dtype(Q + [G])
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
@@ -1078,88 +1083,115 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
  term1 = promote(torch.einsum(exprG, A, A))
  term2 = promote(torch.einsum(exprG, conjB, conjB))

- term2 += term1 # a + b
- term1 *= 2 # 2a
- if term1.dtype == term2.dtype:
- term1 -= term2 # 2a - (a + b) == a - b
- else:
- term1 = term1 - term2
+ term1, term2 = term1 - term2, term1 + term2

  term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
- term1 *= q.to(term1.dtype)
- term1 /= norm.clamp_(min=tiny_bf16)
+ term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
  else:
  torch.triu(term1, out=term1)
- term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
- torch.matmul(term1, q, out=term1)
+ term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
+ term1 = torch.mm(term1, q)
  if store_triu_as_line:
  term1 = triu_to_line([term1])[0][1]
  o = o[1]
- stochastic_add_([o], [term1], -1)
+ stochastic_add_(o, term1, -1)


  @decorator_knowngood
- def _compilable_l2_clip_(x):
+ def _compilable_l2_clip_(x, clip_at):
  ref = x
  x = list(map(promote, x))
  norm = torch._foreach_norm(x)
- torch._foreach_maximum_(norm, 1e-8)
+ torch._foreach_maximum_(norm, clip_at)
  out = torch._foreach_div(x, norm)
  return stochastic_round_list_(ref, out)


- def l2_clip_(x):
+ def l2_normalization_(x, clip_at: float = 1e-8):
  x = list_guard(x)
- return _compilable_l2_clip_(x)
+ return _compilable_l2_clip_(x, clip_at)
+
+
+ def l2_clip_(x, clip_at: float = 1.):
+ x = list_guard(x)
+ return _compilable_l2_clip_(x, clip_at)


  @decorator_knowngood
- def _compilable_rmsnorm_clip_(x):
+ def _compilable_rmsnorm_clip_(x, clip_at):
  x = list(map(promote, x))
  norm = torch._foreach_norm(x)
  norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
- torch._foreach_maximum_(norm, 1e-6)
+ torch._foreach_maximum_(norm, clip_at)
  return torch._foreach_div(x, norm)


- def rmsnorm_clip_(x):
+ def rmsnorm_clip_(x, clip_at: float = 1.0):
  x = list_guard(x)
- return _compilable_rmsnorm_clip_(x)
+ return _compilable_rmsnorm_clip_(x, clip_at)


- def mu_law_compress(x, mu=127.0):
+ def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+ x = list_guard(x)
+ return _compilable_rmsnorm_clip_(x, clip_at)
+
+
+ @decorator_knowngood
+ def _compilable_mu_law_compress_(x, mu):
  """
- Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+ original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+ """
+
+ for x_ in x:
+ xa = promote(x_.abs()) * mu
+ xa = xa.log1p()
+ xa = xa / math.log1p(mu)
+ xa = xa.copysign(x_)
+ copy_stochastic_(x_, xa)

+
+ def mu_law_compress(x, mu=127.0):
+ """
  μ-law compression
  Args:
  x: Input tensor
  mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
  """
- xa = torch._foreach_abs_(x)
- torch._foreach_mul_(xa, mu)
- torch._foreach_log1p_(xa)
- torch._foreach_div_(xa, math.log1p(mu))
- return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+ x = list_guard(x)
+ mu = scalar_guard(mu, x[0])
+ _compilable_mu_law_compress_(x, mu)
+ return x


- def a_law_compress(x, A=87.6):
+ @decorator_knowngood
+ def _compilable_a_law_compress_(x, A):
+ """
+ original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
  """
- Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+ for x_ in x:
+ xa = promote(x_.abs()) * A
+ xa = torch.where(xa < 1, xa, 1 + xa.log())
+ xa = xa.copysign(x_)
+ xa = xa * (1 / (1 + math.log(A)))
+ copy_stochastic_(x_, xa)

+
+ def a_law_compress(x, A=87.6):
+ """
  A-law compression
  Args:
  x: Input tensor
  A: Compression parameter (default 87.6 - European PCM standard)
+ :param x:
+ :param A:
+ :return:
  """
- xa = torch._foreach_abs(x)
- torch._foreach_mul_(xa, A)
- [torch.where(x_ < 1, x_, 1 + torch.log_(x_), out=x_) for x_ in xa]
- [xa_.copysign(x_) for x_, xa_ in zip(x, xa)]
- torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
- return xa
+ x = list_guard(x)
+ A = scalar_guard(A, x[0])
+ _compilable_a_law_compress_(x, A)
+ return x


  def identity(x):
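For reference, the rewritten compressors implement the standard companding formulas: μ-law is sign(x)·log(1 + μ|x|)/log(1 + μ), and A-law is sign(x)·(A|x| if A|x| < 1 else 1 + log(A|x|))/(1 + log A). A plain-tensor sketch of the math without the foreach/stochastic-rounding machinery (the `mu_law`/`a_law` helpers are hypothetical, not heavyball's compiled paths):

```python
# Reference formulas behind the rewritten compressors, on plain tensors.
import math

import torch


def mu_law(x: torch.Tensor, mu: float = 127.0) -> torch.Tensor:
    # sign(x) * log(1 + mu * |x|) / log(1 + mu)
    return (x.abs().mul(mu).log1p() / math.log1p(mu)).copysign(x)


def a_law(x: torch.Tensor, A: float = 87.6) -> torch.Tensor:
    # sign(x) * (A|x| if A|x| < 1 else 1 + log(A|x|)) / (1 + log(A))
    xa = x.abs() * A
    xa = torch.where(xa < 1, xa, 1 + xa.log())
    return xa.copysign(x) / (1 + math.log(A))
```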
@@ -1167,24 +1199,24 @@ def identity(x):


  @decorator_knowngood
- def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+ def _compilable_trust_region_clip_(grad, lerp, scale):
  # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
- g32 = list(map(promote, grad))
- [g.mul_(1 / scale) for g in g32]
- tanh = torch._foreach_tanh(g32)
- torch._foreach_abs_(g32)
- torch._foreach_log1p_(g32)
- [g.copysign_(t).lerp_(t, lerp).mul_(scale) for t, g in zip(tanh, g32)]
-
- torch._foreach_maximum_(g32, -2)
- torch._foreach_minimum_(g32, 2)
- return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+ for x_ in grad:
+ x = promote(x_)
+ x = x / scale
+ tanh = x.tanh()
+ x = x.abs().log1p()
+ x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+ x = x * scale
+ x = x.clamp(min=-2, max=2)
+ copy_stochastic_(x_, x)


  def trust_region_clip_(grad, lerp=0.9, scale=1.5):
  grad = list_guard(grad)
  lerp, scale = scalar_guard(lerp, scale, grad[0])
- return _compilable_trust_region_clip_(grad, lerp, scale)
+ _compilable_trust_region_clip_(grad, lerp, scale)
+ return grad


  @decorator
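The trust-region clip keeps the transform noted in the inline comment, (sgn(x)·log(1 + |x|)·(1 − lerp) + tanh(x)·lerp)·scale clamped to [−2, 2] with the input pre-divided by scale, but now writes back in place and returns the input list. A plain-tensor sketch of the math (the `trust_region` helper is hypothetical):

```python
import torch


def trust_region(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
    # Mirrors the per-tensor body above, without promote()/copy_stochastic_().
    x = x / scale
    tanh = x.tanh()
    soft = x.abs().log1p().copysign(tanh)  # sign-matched log(1 + |x|)
    out = soft * (1 - lerp) + tanh * lerp  # blend log-compression with tanh
    return (out * scale).clamp(min=-2, max=2)
```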
@@ -1262,6 +1294,7 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra


  def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+
  lr = scalar_guard(lr, param[0])
  _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)

@@ -1277,7 +1310,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):

  @decorator_knowngood
  def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
- precond = psgd_precond_grad(expr, grad, *preconds)
+ precond = psgd_precond_grad(expr, ea, *preconds)
  update_param_(param, precond, lr, decay, caution=caution, grad=grad)

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.3.0
+ Version: 1.4.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
  name='heavyball',
  license='BSD',
  description='Efficient optimizers',
- version='1.3.0',
+ version='1.4.0',
  long_description=README,
  url='https://github.com/clashluke/heavyball',
  packages=setuptools.find_packages(),