heavyball 0.21.3.tar.gz → 0.21.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.21.3 → heavyball-0.21.5}/PKG-INFO +1 -1
  2. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/cached_delayed_psgd_kron.py +1 -1
  3. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/cached_psgd_kron.py +1 -1
  4. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/delayed_psgd.py +5 -5
  5. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/p_adam.py +2 -2
  6. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/utils.py +4 -4
  7. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/PKG-INFO +1 -1
  8. {heavyball-0.21.3 → heavyball-0.21.5}/setup.py +1 -1
  9. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_bf16_params.py +2 -0
  10. {heavyball-0.21.3 → heavyball-0.21.5}/LICENSE +0 -0
  11. {heavyball-0.21.3 → heavyball-0.21.5}/README.md +0 -0
  12. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/__init__.py +0 -0
  13. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/foreach_adamw.py +0 -0
  14. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/foreach_adopt.py +0 -0
  15. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/foreach_laprop.py +0 -0
  16. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/foreach_sfadamw.py +0 -0
  17. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/foreach_soap.py +0 -0
  18. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/palm_foreach_sfadamw.py +0 -0
  19. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/palm_foreach_soap.py +0 -0
  20. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/precond_schedule_foreach_soap.py +0 -0
  21. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  22. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/precond_schedule_sfpsoap.py +0 -0
  23. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/psgd_kron.py +0 -0
  24. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/pure_psgd.py +0 -0
  25. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  26. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/SOURCES.txt +0 -0
  27. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/dependency_links.txt +0 -0
  28. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/requires.txt +0 -0
  29. {heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/top_level.txt +0 -0
  30. {heavyball-0.21.3 → heavyball-0.21.5}/setup.cfg +0 -0
  31. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_closure.py +0 -0
  34. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_ema.py +0 -0
  35. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_foreach.py +0 -0
  36. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_memory.py +0 -0
  37. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_merge.py +0 -0
  38. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_no_grad.py +0 -0
  39. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_psgd.py +0 -0
  40. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_soap.py +0 -0
  41. {heavyball-0.21.3 → heavyball-0.21.5}/test/test_stochastic_updates.py +0 -0
{heavyball-0.21.3 → heavyball-0.21.5}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.21.3
+ Version: 0.21.5
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball/cached_delayed_psgd_kron.py
@@ -120,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)

  if should_update:
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball/cached_psgd_kron.py
@@ -128,4 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn)
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball/delayed_psgd.py
@@ -11,10 +11,10 @@ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust
  split_p_and_g_in_group, triu_to_line, line_to_triu, promote


- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn):
  new = psgd_precond_grad(q, exprs, ea)
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
+ update_param_([p], clip_fn([new]), lr, weight_decay)


  class ForeachDelayedPSGD(PSGDBase):
@@ -62,7 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -111,7 +111,7 @@ class ForeachDelayedPSGD(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
- _compilable_psgd_precond_grad_(q, state["exprs"], ea, p, lr, weight_decay)
+ _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn)
  if should_update:
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
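Note on the delayed_psgd.py hunks: _compilable_psgd_precond_grad_ is a module-level function compiled with fullgraph=True, so the old body's reference to self.clip_fn could never resolve (there is no self in scope); 0.21.5 passes the clipping function in as an explicit argument and pins dynamic=False. A minimal sketch of the same pattern, using a hypothetical _apply_update helper rather than heavyball's actual code:

import torch

@torch.compile(fullgraph=True, dynamic=False)
def _apply_update(param, update, lr, clip_fn):
    # clip_fn is supplied by the caller (e.g. an optimizer's self.clip_fn),
    # so the compiled graph never has to capture the optimizer instance
    update = clip_fn([update])[0]
    param.sub_(update * lr)

p = torch.zeros(4)
_apply_update(p, torch.ones(4), torch.tensor(0.1), lambda xs: [x.clamp(-1., 1.) for x in xs])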
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball/p_adam.py
@@ -6,7 +6,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e

  import torch

- from heavyball.utils import triu_to_line, line_to_triu, identity
+ from heavyball.utils import triu_to_line, line_to_triu, identity, stochastic_lerp_
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, exp_avg_sq_, beta_debias, \
  split_p_and_g_in_group, promote

@@ -100,7 +100,7 @@ class ForeachPaLMPAdam(PSGDBase):
  for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
  q32 = [promote(qq_) for qq_ in q_]
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
- torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
+ stochastic_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))

  beta2 = 1 - group['step'] ** -group['beta2_scale']

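Note on the p_adam.py hunks: the exponential moving average is now updated through heavyball's stochastic_lerp_ instead of torch._foreach_lerp_, which matters when optimizer state is stored in bfloat16 (see the storage_dtype default and the bf16 test touched in this release). The sketch below illustrates stochastic-rounding interpolation under that assumption; it is not the library's implementation:

import torch

def stochastic_lerp_sketch(state, grads, weight):
    # interpolate in float32, then write back with stochastic rounding so a
    # bfloat16 buffer stays unbiased in expectation over many small updates
    for s, g in zip(state, grads):
        new = s.float().lerp_(g.float(), weight)
        if s.dtype == torch.bfloat16:
            bits = new.view(torch.int32)
            bits += torch.randint_like(bits, 0, 1 << 16)  # random offset below one bf16 ulp
            bits &= -(1 << 16)                            # truncate to a bf16-representable value
        s.copy_(new)

exp_avg = [torch.zeros(8, dtype=torch.bfloat16)]
stochastic_lerp_sketch(exp_avg, [torch.randn(8)], 0.1)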
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball/utils.py
@@ -966,17 +966,17 @@ class PSGDBase(StatefulOptimizer):


  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
+ def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
  md = min_dtype(cached_q + [ea])
  new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
- update_param_([param], self.clip_fn([new]), lr, weight_decay)
+ update_param_([param], clip_fn([new]), lr, weight_decay)


  def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
- weight_decay: float):
+ weight_decay: float, clip_fn):
  if isinstance(lr, float):
  lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
- _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
+ _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)


  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
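Note on the utils.py hunks: besides threading clip_fn through, precond_grad_cached_ keeps the pattern of promoting a float learning rate to a 0-dim tensor before calling the compiled kernel; a Python float would be baked into the compiled graph as a constant, so every new scheduled lr value would trigger a recompile, while a tensor argument reuses one graph. A minimal sketch of that wrapper pattern with a hypothetical _compiled_step:

import torch

@torch.compile(fullgraph=True, dynamic=True)
def _compiled_step(param, grad, lr):
    # lr arrives as a tensor, so changing the learning rate reuses this graph
    param.sub_(grad * lr)

def step(param, grad, lr):
    if isinstance(lr, float):
        # same promotion as precond_grad_cached_: float -> 0-dim float32 tensor
        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
    _compiled_step(param, grad, lr)

p = torch.zeros(4)
step(p, torch.ones(4), 1e-3)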
{heavyball-0.21.3 → heavyball-0.21.5}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.21.3
+ Version: 0.21.5
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
{heavyball-0.21.3 → heavyball-0.21.5}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
  name='heavyball',
  license='BSD',
  description='Efficient optimizers',
- version='0.21.3',
+ version='0.21.5',
  long_description=README,
  url='https://github.com/clashluke/heavyball',
  packages=setuptools.find_packages(),
{heavyball-0.21.3 → heavyball-0.21.5}/test/test_bf16_params.py
@@ -22,6 +22,8 @@ def get_memory():
  @pytest.mark.parametrize("size,depth", [(256, 2)])
  def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 3):
  set_torch()
+ if 'psgd' not in opt.lower() and 'padam' not in opt.lower():
+ raise pytest.skip('Only PSGD and PaLMPAdam are supported')
  opt = getattr(heavyball, opt)


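Note on the test_bf16_params.py hunk: the added guard skips parametrized optimizers whose names do not indicate a PSGD or PAdam variant, at run time rather than at collection time. A minimal sketch of the same pytest pattern, with an illustrative optimizer list:

import pytest

OPTIMIZERS = ['ForeachPSGDKron', 'ForeachPaLMPAdam', 'ForeachAdamW']  # illustrative only

@pytest.mark.parametrize('opt', OPTIMIZERS)
def test_bf16_params(opt):
    if 'psgd' not in opt.lower() and 'padam' not in opt.lower():
        pytest.skip('Only PSGD and PaLMPAdam are supported')  # skip() raises, ending the test here
    # ... build the optimizer and exercise it with bfloat16 parameters ...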