heavyball 0.24.4.tar.gz → 0.25.0.tar.gz

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (45)
  1. {heavyball-0.24.4 → heavyball-0.25.0}/PKG-INFO +1 -1
  2. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/utils.py +10 -7
  3. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-0.24.4 → heavyball-0.25.0}/setup.py +1 -1
  5. {heavyball-0.24.4 → heavyball-0.25.0}/LICENSE +0 -0
  6. {heavyball-0.24.4 → heavyball-0.25.0}/README.md +0 -0
  7. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/__init__.py +0 -0
  8. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
  9. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/cached_psgd_kron.py +0 -0
  10. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/delayed_psgd.py +0 -0
  11. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.24.4 → heavyball-0.25.0}/setup.cfg +0 -0
  30. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_params.py +0 -0
  31. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_caution.py +0 -0
  34. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_channels_last.py +0 -0
  35. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_closure.py +0 -0
  36. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_ema.py +0 -0
  37. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_foreach.py +0 -0
  38. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_hook.py +0 -0
  39. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_mars.py +0 -0
  40. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_memory.py +0 -0
  41. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_merge.py +0 -0
  42. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_no_grad.py +0 -0
  43. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_psgd.py +0 -0
  44. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_soap.py +0 -0
  45. {heavyball-0.24.4 → heavyball-0.25.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.24.4
3
+ Version: 0.25.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
492
492
  super().__init__(params, {**defaults, 'foreach': foreach})
493
493
  self.fake_groups = {}
494
494
  self.use_ema = use_ema
495
-
496
- def key(self, param: Tensor):
497
- return (param.data_ptr(), tuple(param.shape))
495
+ self.mapping = {}
498
496
 
499
497
  def get_groups(self, group):
500
498
  if group['foreach']:
501
499
  return [group]
502
500
 
503
501
  for p in group['params']:
504
- if self.key(p) not in self.fake_groups:
505
- self.fake_groups[self.key(p)] = {**group, 'params': [p]}
502
+ if p not in self.fake_groups:
503
+ self.fake_groups[p] = {**group, 'params': [p]}
506
504
 
507
- return [self.fake_groups[self.key(p)] for p in group['params']]
505
+ return [self.fake_groups[p] for p in group['params']]
508
506
 
509
507
  def state_(self, arg: Tensor):
510
- return self.state[self.key(arg)]
508
+ return self.state[self.mapping.get(arg, arg)]
511
509
 
512
510
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
513
511
  for p, g in zip(p_list, g_list):
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
538
536
  p_views = merge_group(group, p)
539
537
  if grad is not None:
540
538
  grad = merge_group(group, grad)
539
+ for i, pv in enumerate(p_views):
540
+ self.mapping[pv] = (p, i)
541
541
  if isinstance(p_views, Tensor):
542
542
  yield p_views, grad
543
543
  continue
@@ -622,11 +622,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
622
622
  for top_group in self.param_groups:
623
623
  for group in self.get_groups(top_group):
624
624
  self._step(group)
625
+ self.mapping.clear()
625
626
  if self.use_ema:
626
627
  self.ema_update(group)
628
+
627
629
  return loss
628
630
 
629
631
 
632
+
630
633
  class ScheduleFree(StatefulOptimizer):
631
634
  def eval(self):
632
635
  for group in self.param_groups:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.24.4
3
+ Version: 0.25.0
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.24.4',
13
+ version='0.25.0',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes