heavyball-0.21.0-py3-none-any.whl → heavyball-0.21.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/cached_delayed_psgd_kron.py CHANGED
@@ -7,9 +7,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  from typing import Optional

  import torch
+ from heavyball.utils import min_dtype, precond_grad_cached_

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


  class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -59,7 +60,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+ storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

  def _step(self, group):
@@ -118,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)

- new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)

  if should_update:
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -130,5 +132,3 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
  torch.matmul(q_.T.conj(), q_, out=c_)
  else:
  torch.mul(q_.conj(), q_, out=c_)
-
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/cached_psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch

  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
- line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_


  class ForeachCachedPSGDKron(PSGDBase):
@@ -128,5 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
  else:
  torch.mul(q_.conj(), q_, out=c_)

- g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
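
Both cached Kron variants above now delegate the einsum over the cached preconditioner factors, plus the parameter update, to the new precond_grad_cached_ helper (added to heavyball/utils.py further down in this diff). As a rough standalone sketch of what that call computes, with illustrative names only and the clip_fn/update_param_ path simplified away:

    import torch

    def apply_cached_precond(cached_q, ea, cache_expr, param, lr):
        # Contract the cached Kronecker preconditioner factors with the momentum
        # buffer. The helper in utils.py additionally routes the result through
        # clip_fn and applies weight decay via update_param_.
        new = torch.einsum(cache_expr, *cached_q, ea)
        param.add_(new, alpha=lr)  # lr arrives pre-negated, as in the delayed_psgd.py hunk below

    # Example: a 2-D parameter with one cached factor per dimension.
    p = torch.randn(8, 4)
    cached_q = [torch.randn(8, 8), torch.randn(4, 4)]
    ea = torch.randn(8, 4)
    apply_cached_precond(cached_q, ea, 'ab,cd,ac->bd', p, lr=-1e-3)
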
heavyball/delayed_psgd.py CHANGED
@@ -5,12 +5,18 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
  """

  import torch
-
  from heavyball.utils import stochastic_lerp_, beta_debias
+
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
  split_p_and_g_in_group, triu_to_line, line_to_triu, promote


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay):
+ new = psgd_precond_grad(q, exprs, ea)
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
+
+
  class ForeachDelayedPSGD(PSGDBase):
  """
  Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
@@ -39,7 +45,7 @@ class ForeachDelayedPSGD(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
- q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -98,15 +104,15 @@ class ForeachDelayedPSGD(PSGDBase):
  stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))

  lr = -warmup(lr, group['step'], group['warmup_steps'])
+ lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)

  Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
- new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+ _compilable_psgd_precond_grad_(q, state["exprs"], ea, p, lr, weight_decay)
  if should_update:
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
  store_triu_as_line)
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
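
Note the added line in _step above: the Python-float learning rate is wrapped in a 0-dim float32 tensor before being handed to the torch.compile'd helper. Passing the scalar as a tensor is the usual way to keep one compiled graph reusable while warmup changes the value every step, since a plain float argument is typically specialized into the graph and forces recompilation. A minimal sketch of the pattern with toy names, not the library's code:

    import torch

    @torch.compile(dynamic=True)
    def scaled_add_(param, update, lr):
        # lr is a 0-dim tensor, so the same compiled graph serves every value.
        param.add_(update * lr)

    p, g = torch.randn(4), torch.randn(4)
    for step_lr in (1e-4, 2e-4, 3e-4):  # changing scalar, single compilation
        lr_t = torch.empty((), dtype=torch.float32, device=p.device).fill_(-step_lr)
        scaled_add_(p, g, lr_t)
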
heavyball/utils.py CHANGED
@@ -965,6 +965,20 @@ class PSGDBase(StatefulOptimizer):
  psgd_balance_Q(q)


+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
+ md = min_dtype(cached_q + [ea])
+ new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
+ update_param_([param], self.clip_fn([new]), lr, weight_decay)
+
+
+ def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
+ weight_decay: float):
+ if isinstance(lr, float):
+ lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
+ _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
+
+
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
  """Anneal preconditioner update probability during beginning of training.

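
The new public precond_grad_cached_ wrapper normalizes a float lr into a scalar tensor and dispatches to a compiled inner function, which casts the cached factors and the momentum buffer to a common dtype for the einsum and casts the result back to float32. A self-contained sketch of that dtype-matching idea; lowest_common_dtype below is a hypothetical stand-in for heavyball.utils.min_dtype, whose exact semantics are not shown in this diff:

    import torch

    def lowest_common_dtype(tensors):
        # Hypothetical stand-in for min_dtype: assume floating-point inputs and
        # pick the lowest-precision dtype present so the einsum runs there.
        order = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
        return min((t.dtype for t in tensors), key=order.index)

    def cached_precond_grad(cached_q, ea, expr):
        md = lowest_common_dtype(cached_q + [ea])
        # Contract in the reduced dtype, return float32 for the parameter update.
        return torch.einsum(expr, *[q.to(md) for q in cached_q], ea.to(md)).to(torch.float32)

    ea = torch.randn(8, 4, dtype=torch.bfloat16)
    cached_q = [torch.randn(8, 8), torch.randn(4, 4)]  # float32 factors
    out = cached_precond_grad(cached_q, ea, 'ab,cd,ac->bd')  # einsum runs in bfloat16
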
heavyball-0.21.0.dist-info/METADATA → heavyball-0.21.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.21.0
+ Version: 0.21.2
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
heavyball-0.21.0.dist-info/RECORD → heavyball-0.21.2.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
  heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
- heavyball/cached_delayed_psgd_kron.py,sha256=apVzESMaQ8uxunHvfvYfyWA8HLbS25wQSd3j_YNEjGs,6603
- heavyball/cached_psgd_kron.py,sha256=3IETfsC0Ufu_8TPfo9SByGmztwjW6ktSFPwHNrUWkys,6601
- heavyball/delayed_psgd.py,sha256=0LaazbiBZOdx78EDS-945cW3bmeORjUvdFOGqdw3aMs,5631
+ heavyball/cached_delayed_psgd_kron.py,sha256=jBMA5MI4EpOIKww0iQ0zkRO4g1EPPvWHpNWS5ndgDo8,6660
+ heavyball/cached_psgd_kron.py,sha256=5qENxAs2N0YRWI0CHFK30bdffxVjOagfAonODKnkDnI,6578
+ heavyball/delayed_psgd.py,sha256=b0Fd-3IoqC5Q0eSCGUGkOmKOPHsQke_Ozv4uWVEAZoU,5929
  heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
  heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
  heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
  heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
  heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
  heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
- heavyball/utils.py,sha256=H8RsADNAXVbjQ9wWstYIKkXMq9E81aUF1j-2wfCeSLA,36471
- heavyball-0.21.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-0.21.0.dist-info/METADATA,sha256=hbXhr4XcPAkgfW8hpgoPRPrUoKeTCTvhZdofj4h8_8c,11926
- heavyball-0.21.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-0.21.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-0.21.0.dist-info/RECORD,,
+ heavyball/utils.py,sha256=9PaYhieWXIo6uE-Bb48N45wOg-AmQEkmNnTXjtKByfo,37211
+ heavyball-0.21.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-0.21.2.dist-info/METADATA,sha256=XvA88WVYQdXXj04iYuPDUWM2rKiUZInJkNvC5BM-zqA,11926
+ heavyball-0.21.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-0.21.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-0.21.2.dist-info/RECORD,,
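
For anyone checking these entries by hand: each RECORD line is path,hash,size-in-bytes, where the hash is the urlsafe base64 encoding of the file's SHA-256 digest with the trailing '=' padding stripped, as specified for wheel RECORD files. A small sketch for recomputing an entry locally (the path is illustrative):

    import base64, hashlib

    def record_entry(path):
        # Rebuild a RECORD line: urlsafe-b64(sha256 digest) without '=' padding, plus byte size.
        data = open(path, 'rb').read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
        return f"{path},sha256={digest},{len(data)}"

    print(record_entry('heavyball/utils.py'))  # compare against the '+' line above
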