heavyball 0.20.1__py3-none-any.whl → 0.21.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,9 +7,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
7
7
  from typing import Optional
8
8
 
9
9
  import torch
10
+ from heavyball.utils import min_dtype, precond_grad_cached_
10
11
 
11
12
  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
12
- line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
13
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
13
14
 
14
15
 
15
16
  class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -59,7 +60,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
59
60
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
60
61
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
61
62
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
62
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
63
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
64
+ storage_dtype=storage_dtype)
63
65
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
64
66
 
65
67
  def _step(self, group):
@@ -118,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
118
120
  q_orig = Q_list.pop(0)
119
121
  ea = exp_avg_list.pop(0)
120
122
 
121
- new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
123
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
122
124
 
123
125
  if should_update:
124
126
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -130,5 +132,3 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
130
132
  torch.matmul(q_.T.conj(), q_, out=c_)
131
133
  else:
132
134
  torch.mul(q_.conj(), q_, out=c_)
133
-
134
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
@@ -9,7 +9,7 @@ from typing import Optional
9
9
  import torch
10
10
 
11
11
  from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
12
- line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias
12
+ line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
13
13
 
14
14
 
15
15
  class ForeachCachedPSGDKron(PSGDBase):
@@ -128,5 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
128
128
  else:
129
129
  torch.mul(q_.conj(), q_, out=c_)
130
130
 
131
- g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
132
- update_param_([p], self.clip_fn([g]), lr, weight_decay)
131
+ precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], ea, p, lr, weight_decay)
heavyball/delayed_psgd.py CHANGED
@@ -5,8 +5,8 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
5
5
  """
6
6
 
7
7
  import torch
8
-
9
8
  from heavyball.utils import stochastic_lerp_, beta_debias
9
+
10
10
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
11
11
  split_p_and_g_in_group, triu_to_line, line_to_triu, promote
12
12
 
@@ -39,7 +39,7 @@ class ForeachDelayedPSGD(PSGDBase):
39
39
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
40
40
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
41
41
  split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
42
- q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
42
+ q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32', #
43
43
  # expert parameters
44
44
  precond_init_scale=1.0, precond_lr=0.1):
45
45
  if not 0.0 <= lr:
@@ -105,8 +105,8 @@ class ForeachDelayedPSGD(PSGDBase):
105
105
  ea = exp_avg_list.pop(0)
106
106
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
107
107
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
108
+ update_param_([p], self.clip_fn([new]), lr, weight_decay)
108
109
  if should_update:
109
110
  q32 = [promote(q_) for q_ in q]
110
111
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
111
112
  store_triu_as_line)
112
- update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/utils.py CHANGED
@@ -329,23 +329,33 @@ def get_orthogonal_matrix(mat):
329
329
 
330
330
 
331
331
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
332
- def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
333
- x32 = [promote(x_) for x_ in x]
334
- y32 = [promote(y_) for y_ in y]
332
+ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
333
+ for x_, y_ in zip(x, y):
334
+ x32 = promote(x_)
335
+ y32 = promote(y_)
336
+ x32.lerp_(y32, a)
337
+ copy_stochastic_(x_, x32)
335
338
 
336
- torch._foreach_lerp_(x32, y32, a)
337
339
 
338
- copy_stochastic_list_(x, x32)
340
+ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
341
+ if not isinstance(a, torch.Tensor):
342
+ a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
343
+ _compilable_stochastic_lerp_(x, y, a)
339
344
 
340
345
 
341
346
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
342
- def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
343
- x32 = [promote(x_) for x_ in x]
344
- y32 = [promote(y_) for y_ in y]
347
+ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
348
+ for x_, y_ in zip(x, y):
349
+ x32 = promote(x_)
350
+ y32 = promote(y_)
351
+ x32.add_(y32, alpha=alpha)
352
+ copy_stochastic_(x_, x32)
345
353
 
346
- [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
347
354
 
348
- copy_stochastic_list_(x, x32)
355
+ def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
356
+ if not isinstance(alpha, torch.Tensor):
357
+ alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
358
+ _compilable_stochastic_add_(x, y, alpha)
349
359
 
350
360
 
351
361
  @decorator
@@ -572,7 +582,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
572
582
 
573
583
 
574
584
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
575
- def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
585
+ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
576
586
  beta1 = beta_debias(beta1, step)
577
587
  beta2 = beta_debias(beta2, step)
578
588
 
@@ -585,6 +595,18 @@ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
585
595
  return denom
586
596
 
587
597
 
598
+ def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
599
+ grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
600
+ if isinstance(beta1, float):
601
+ beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
602
+ if isinstance(beta2, float):
603
+ beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
604
+ if isinstance(step, int):
605
+ step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
606
+ denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
607
+ return denom
608
+
609
+
588
610
  # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
589
611
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
590
612
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
@@ -734,7 +756,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
734
756
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
735
757
  order = G.dim()
736
758
  p = list(range(order))
737
- conjB = torch.randn_like(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
759
+ conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
738
760
  Q = [promote(q) for q in Q]
739
761
  for i, q in enumerate(Q):
740
762
  if q.dim() <= 1:
@@ -756,17 +778,17 @@ def psgd_lb(A, max_abs):
756
778
  a0 = torch.einsum('ij,ij->j', A, A)
757
779
  i = torch.argmax(a0)
758
780
 
759
- x = torch.index_select(a, 1, i).flatten().contiguous()
781
+ x = torch.index_select(A, 1, i).flatten().contiguous()
760
782
 
761
- x = torch.einsum('i,ij->j', x_, a)
783
+ x = torch.einsum('i,ij->j', x, A)
762
784
  x /= x.norm()
763
- x = torch.einsum('j,kj->k', x_, a)
785
+ x = torch.einsum('j,kj->k', x, A)
764
786
  x = x.norm()
765
787
  x *= max_abs
766
788
  return x
767
789
 
768
790
 
769
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
791
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
770
792
  def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
771
793
  """Update Kronecker product preconditioner Q with pair (V, G)."""
772
794
  exprA, exprGs, _ = exprs
@@ -799,7 +821,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
799
821
  stochastic_add_([o], [term1], -1)
800
822
 
801
823
 
802
- @decorator
824
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
803
825
  def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
804
826
  """Precondition gradient G with preconditioner Q."""
805
827
  md = min_dtype(Q)
@@ -943,6 +965,20 @@ class PSGDBase(StatefulOptimizer):
943
965
  psgd_balance_Q(q)
944
966
 
945
967
 
968
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
969
+ def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay):
970
+ md = min_dtype(cached_q + [ea])
971
+ new = torch.einsum(self.state_(p)['cache_expr'], *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
972
+ update_param_([param], self.clip_fn([new]), lr, weight_decay)
973
+
974
+
975
+ def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
976
+ weight_decay: float):
977
+ if isinstance(lr, float):
978
+ lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
979
+ _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay)
980
+
981
+
946
982
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
947
983
  """Anneal preconditioner update probability during beginning of training.
948
984
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.20.1
3
+ Version: 0.21.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
32
32
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
33
33
  largely static alternative to `torch.optim` with more and better optimizers.
34
34
 
35
- Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
35
+ Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
36
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
37
 
38
38
  ## Features
@@ -1,7 +1,7 @@
1
1
  heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
2
- heavyball/cached_delayed_psgd_kron.py,sha256=apVzESMaQ8uxunHvfvYfyWA8HLbS25wQSd3j_YNEjGs,6603
3
- heavyball/cached_psgd_kron.py,sha256=3IETfsC0Ufu_8TPfo9SByGmztwjW6ktSFPwHNrUWkys,6601
4
- heavyball/delayed_psgd.py,sha256=0LaazbiBZOdx78EDS-945cW3bmeORjUvdFOGqdw3aMs,5631
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=QMSQ1es2ViAf6KWEuePihyjqWfTN22-RNYve31CCZdI,6664
3
+ heavyball/cached_psgd_kron.py,sha256=NIWDUUrocU8rNR5HsIyR28XtxZ3OX7uKa25-J21YCJk,6582
4
+ heavyball/delayed_psgd.py,sha256=I1EhQU7CUBLY94fXIdrzLtgIpfnQHU5XWaKf0S2Ls34,5635
5
5
  heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
6
6
  heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
7
7
  heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy
16
16
  heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
17
17
  heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
19
- heavyball/utils.py,sha256=14vt4r_MeTsp1q3m0lpgF-Q3PCJg6GLGJrhjRxnbWwQ,35174
20
- heavyball-0.20.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.20.1.dist-info/METADATA,sha256=qzF2P7e2EREeTy_4h85tvUY53omjNm32z83CUHTqt3U,11926
22
- heavyball-0.20.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.20.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.20.1.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=9PaYhieWXIo6uE-Bb48N45wOg-AmQEkmNnTXjtKByfo,37211
20
+ heavyball-0.21.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.21.1.dist-info/METADATA,sha256=NzvyCgg7cX4R_ouN9uQ1wE4KC6O2aQGwZLSRYMCSR7Q,11926
22
+ heavyball-0.21.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.21.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.21.1.dist-info/RECORD,,