heavyball 0.21.1__tar.gz → 0.21.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.21.1 → heavyball-0.21.3}/PKG-INFO +1 -1
  2. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/cached_delayed_psgd_kron.py +1 -1
  3. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/cached_psgd_kron.py +1 -1
  4. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/delayed_psgd.py +8 -2
  5. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/utils.py +1 -1
  6. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball.egg-info/PKG-INFO +1 -1
  7. {heavyball-0.21.1 → heavyball-0.21.3}/setup.py +1 -1
  8. {heavyball-0.21.1 → heavyball-0.21.3}/LICENSE +0 -0
  9. {heavyball-0.21.1 → heavyball-0.21.3}/README.md +0 -0
  10. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/__init__.py +0 -0
  11. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.21.1 → heavyball-0.21.3}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.21.1 → heavyball-0.21.3}/setup.cfg +0 -0
  30. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_bf16_params.py +0 -0
  31. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_closure.py +0 -0
  34. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_ema.py +0 -0
  35. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_foreach.py +0 -0
  36. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_memory.py +0 -0
  37. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_merge.py +0 -0
  38. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_no_grad.py +0 -0
  39. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_psgd.py +0 -0
  40. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_soap.py +0 -0
  41. {heavyball-0.21.1 → heavyball-0.21.3}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.21.1
3
+ Version: 0.21.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -120,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
120
120
  q_orig = Q_list.pop(0)
121
121
  ea = exp_avg_list.pop(0)
122
122
 
123
- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
123
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
124
124
 
125
125
  if should_update:
126
126
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -128,4 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
128
128
  else:
129
129
  torch.mul(q_.conj(), q_, out=c_)
130
130
 
131
- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
131
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
@@ -11,6 +11,12 @@ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust
11
11
  split_p_and_g_in_group, triu_to_line, line_to_triu, promote
12
12
 
13
13
 
14
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
15
+ def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay):
16
+ new = psgd_precond_grad(q, exprs, ea)
17
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
18
+
19
+
14
20
  class ForeachDelayedPSGD(PSGDBase):
15
21
  """
16
22
  Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
@@ -98,14 +104,14 @@ class ForeachDelayedPSGD(PSGDBase):
98
104
  stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
99
105
 
100
106
  lr = -warmup(lr, group['step'], group['warmup_steps'])
107
+ lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)
101
108
 
102
109
  Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
103
110
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
104
111
  q_orig = Q_list.pop(0)
105
112
  ea = exp_avg_list.pop(0)
106
113
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
107
- new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
108
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
114
+ _compilable_psgd_precond_grad_(q, state["exprs"], ea, p, lr, weight_decay)
109
115
  if should_update:
110
116
  q32 = [promote(q_) for q_ in q]
111
117
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
@@ -968,7 +968,7 @@ class PSGDBase(StatefulOptimizer):
968
968
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
969
969
  def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
970
970
  md = min_dtype(cached_q + [ea])
971
- new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
971
+ new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
972
972
  update_param_([param], self.clip_fn([new]), lr, weight_decay)
973
973
 
974
974
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.21.1
3
+ Version: 0.21.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='0.21.1',
13
+ version='0.21.3',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes