heavyball 0.21.0__tar.gz → 0.21.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {heavyball-0.21.0 → heavyball-0.21.1}/PKG-INFO +1 -1
  2. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_delayed_psgd_kron.py +5 -5
  3. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_psgd_kron.py +2 -3
  4. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/delayed_psgd.py +3 -3
  5. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/utils.py +14 -0
  6. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/PKG-INFO +1 -1
  7. {heavyball-0.21.0 → heavyball-0.21.1}/setup.py +1 -1
  8. {heavyball-0.21.0 → heavyball-0.21.1}/LICENSE +0 -0
  9. {heavyball-0.21.0 → heavyball-0.21.1}/README.md +0 -0
  10. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/__init__.py +0 -0
  11. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_adamw.py +0 -0
  12. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_adopt.py +0 -0
  13. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_laprop.py +0 -0
  14. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_sfadamw.py +0 -0
  15. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/foreach_soap.py +0 -0
  16. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/p_adam.py +0 -0
  17. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/palm_foreach_sfadamw.py +0 -0
  18. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/palm_foreach_soap.py +0 -0
  19. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
  20. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  21. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
  22. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/psgd_kron.py +0 -0
  23. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/pure_psgd.py +0 -0
  24. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  25. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/SOURCES.txt +0 -0
  26. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/dependency_links.txt +0 -0
  27. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/requires.txt +0 -0
  28. {heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/top_level.txt +0 -0
  29. {heavyball-0.21.0 → heavyball-0.21.1}/setup.cfg +0 -0
  30. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_params.py +0 -0
  31. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_q.py +0 -0
  32. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_bf16_storage.py +0 -0
  33. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_closure.py +0 -0
  34. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_ema.py +0 -0
  35. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_foreach.py +0 -0
  36. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_memory.py +0 -0
  37. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_merge.py +0 -0
  38. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_no_grad.py +0 -0
  39. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_psgd.py +0 -0
  40. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_soap.py +0 -0
  41. {heavyball-0.21.0 → heavyball-0.21.1}/test/test_stochastic_updates.py +0 -0
{heavyball-0.21.0 → heavyball-0.21.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.0
+Version: 0.21.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_delayed_psgd_kron.py
@@ -7,9 +7,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 from typing import Optional
 
 import torch
+from heavyball.utils import min_dtype, precond_grad_cached_
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -59,7 +60,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
                         precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype)
         super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
     def _step(self, group):
@@ -118,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
 
-            new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
 
             if should_update:
                 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -130,5 +132,3 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
                         torch.matmul(q_.T.conj(), q_, out=c_)
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
-
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/cached_psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
 
 
 class ForeachCachedPSGDKron(PSGDBase):
@@ -128,5 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
                     else:
                         torch.mul(q_.conj(), q_, out=c_)
 
-            g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
-            update_param_([p], self.clip_fn([g]), lr, weight_decay)
+            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/delayed_psgd.py
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-
 from heavyball.utils import stochastic_lerp_, beta_debias
+
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
     split_p_and_g_in_group, triu_to_line, line_to_triu, promote
 
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
+                q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
                 # expert parameters
                 precond_init_scale=1.0, precond_lr=0.1):
        if not 0.0 <= lr:
@@ -105,8 +105,8 @@ class ForeachDelayedPSGD(PSGDBase):
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+            update_param_([p], self.clip_fn([new]), lr, weight_decay)
             if should_update:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            update_param_([p], self.clip_fn([new]), lr, weight_decay)
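Note on the ForeachDelayedPSGD hunk above: the only change in _step is that update_param_ now runs before the `if should_update:` block, so each step applies the preconditioner from the previous refresh and only then updates Q. A runnable toy sketch of that ordering with a hypothetical diagonal preconditioner (the names and the reciprocal-magnitude refresh are illustrative assumptions, not heavyball code):

    import torch

    def delayed_step(param: torch.Tensor, grad: torch.Tensor, precond: torch.Tensor,
                     lr: float, should_update: bool) -> None:
        # Apply the *stale* preconditioner first ...
        param.add_(grad * precond, alpha=-lr)
        # ... and only afterwards refresh it, mirroring the reordering in the diff.
        if should_update:
            precond.copy_(1.0 / (grad.abs() + 1e-8))

    p = torch.ones(3)
    g = torch.full((3,), 0.5)
    pre = torch.ones(3)  # starts as an identity scaling
    delayed_step(p, g, pre, lr=0.1, should_update=True)
    print(p, pre)  # p was moved with the old `pre`; `pre` changed only after the step
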
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball/utils.py
@@ -965,6 +965,20 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
+    md = min_dtype(cached_q + [ea])
+    new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
+    update_param_([param], self.clip_fn([new]), lr, weight_decay)
+
+
+def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
+                         weight_decay: float):
+    if isinstance(lr, float):
+        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
+    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
+
+
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
     """Anneal preconditioner update probability during beginning of training.
 
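The utils.py hunk above adds the fused helper that the cached PSGD variants now call: it contracts the cached Kronecker factors with the momentum buffer via torch.einsum and applies the parameter update inside a single torch.compile region. As published, the body of _compilable_precond_grad_cached_ still references self.state_(p)['cache_expr'] and self.clip_fn even though it is a module-level function, and the call sites pass an extra ea argument. The sketch below is therefore only an illustrative, self-contained reading of the intended behaviour; the min_dtype definition, the decoupled weight decay, and the toy einsum expression are assumptions, not the library's exact code.

    from typing import List

    import torch

    def min_dtype(tensors: List[torch.Tensor]) -> torch.dtype:
        # Assumed behaviour: the narrowest floating dtype among the inputs.
        return min((t.dtype for t in tensors), key=lambda d: torch.finfo(d).bits)

    def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str,
                             param: torch.Tensor, lr: float, weight_decay: float) -> None:
        # Contract the cached factors with the momentum buffer, then update the parameter.
        md = min_dtype(cached_q + [ea])
        new = torch.einsum(expr, *[q.to(md) for q in cached_q], ea.to(md)).to(torch.float32)
        with torch.no_grad():
            if weight_decay > 0:
                param.mul_(1 - lr * weight_decay)  # decoupled weight decay (assumption)
            param.add_(new.to(param.dtype), alpha=-lr)

    # Toy usage: identity cached factors reduce the update to plain SGD on `ea`.
    p = torch.randn(4, 3)
    ea = torch.randn_like(p)
    precond_grad_cached_([torch.eye(4), torch.eye(3)], ea, 'ab,cd,bd->ac', p, lr=1e-3, weight_decay=0.0)
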
{heavyball-0.21.0 → heavyball-0.21.1}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.0
+Version: 0.21.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
{heavyball-0.21.0 → heavyball-0.21.1}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.21.0',
+    version='0.21.1',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),