heavyball 1.2.3__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.2.3 → heavyball-1.3.0}/PKG-INFO +1 -1
  2. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball/chainable.py +8 -7
  3. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball/utils.py +6 -7
  4. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.2.3 → heavyball-1.3.0}/setup.py +1 -1
  6. {heavyball-1.2.3 → heavyball-1.3.0}/LICENSE +0 -0
  7. {heavyball-1.2.3 → heavyball-1.3.0}/README.md +0 -0
  8. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball/__init__.py +0 -0
  9. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.2.3 → heavyball-1.3.0}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.2.3 → heavyball-1.3.0}/setup.cfg +0 -0
  14. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_caution.py +0 -0
  18. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_channels_last.py +0 -0
  19. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_closure.py +0 -0
  20. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_ema.py +0 -0
  21. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_foreach.py +0 -0
  22. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_hook.py +0 -0
  23. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_mars.py +0 -0
  24. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_memory.py +0 -0
  25. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_merge.py +0 -0
  26. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_no_grad.py +0 -0
  27. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_psgd.py +0 -0
  28. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_soap.py +0 -0
  29. {heavyball-1.2.3 → heavyball-1.3.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.3
3
+ Version: 1.3.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -438,14 +438,15 @@ class ChainOpt(utils.StatefulOptimizer):
438
438
 
439
439
  for param in p:
440
440
  state = self.state_(param)
441
- if 'step' not in state:
442
- if self.compile_step:
443
- step = utils.scalar_guard(0, param)
444
- state['step'] = step
445
- step = state['step'].add_(1)
441
+ if 'step' in state:
442
+ step = state['step']
443
+ elif self.compile_step:
444
+ step = utils.scalar_guard(0, param)
445
+ else:
446
+ step = 0
446
447
  break
447
448
 
448
- group['step'] = step
449
+ group['step'] = state['step'] = step = step + 1
449
450
 
450
451
  if group['warmup_steps'] and step < group['warmup_steps']:
451
452
  group['lr'] = group['base_lr'] * step / group['warmup_steps']
@@ -494,7 +495,7 @@ class BaseOpt(ChainOpt):
494
495
  auto_fuse: bool = True
495
496
 
496
497
  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
497
- palm: bool = use_default, compile_step: bool = use_default, *fns):
498
+ palm: bool = use_default, *fns, compile_step: bool = use_default):
498
499
  if default(update_clipping, self.update_clipping) is None:
499
500
  if fns and self.auto_fuse:
500
501
  args, kwargs = None, None
@@ -88,7 +88,7 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
88
88
  except ZeroDivisionError:
89
89
  ckp1 = 0
90
90
 
91
- update, parameters, z = list_guard(update, parameters, z)
91
+ update, parameters, z, grad = list_guard(update, parameters, z, grad)
92
92
  lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
93
93
  _compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
94
94
  return weight_sum
@@ -912,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
912
912
  @decorator_knowngood
913
913
  def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
914
914
  g: List[Optional[Tensor]]):
915
- u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
916
- p32, u32 = [list(map(promote, x)) for x in [p, u]]
917
-
918
- for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
915
+ for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
916
+ u_ = promote(u_.view_as(p_))
917
+ p32_ = promote(p_)
919
918
  if caution:
920
- u32_ = _compilable_cautioning(promote(g_), u32_)
921
- p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
919
+ u_ = _compilable_cautioning(promote(g_), u_)
920
+ p32_ = p32_ * (1 - decay * lr) + u_ * -lr
922
921
  copy_stochastic_(p_, p32_)
923
922
 
924
923
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.3
3
+ Version: 1.3.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='1.2.3',
13
+ version='1.3.0',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes