heavyball 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
heavyball/cached_delayed_psgd_kron.py CHANGED
@@ -62,6 +62,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
62
62
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
63
63
 
64
64
  def _step(self, group):
65
+ should_update = self.should_update(group)
65
66
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
66
67
  precond_init_scale = group['precond_init_scale']
67
68
  max_size_triangular = group['max_size_triangular']
@@ -115,7 +116,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
115
116
 
116
117
  new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
117
118
 
118
- if self.should_update(group):
119
+ if should_update:
119
120
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
120
121
  q32 = [promote(q_) for q_ in q]
121
122
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
heavyball/cached_psgd_kron.py CHANGED
@@ -8,8 +8,8 @@ from typing import Optional
8
8
 
9
9
  import torch
10
10
 
11
- from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
12
- split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base, promote
11
+ from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
12
+ line_to_triu, triu_to_line, set_, einsum_base, promote
13
13
 
14
14
 
15
15
  class ForeachCachedPSGDKron(PSGDBase):
@@ -71,6 +71,7 @@ class ForeachCachedPSGDKron(PSGDBase):
71
71
  beta = group['beta']
72
72
  store_triu_as_line = group['store_triu_as_line']
73
73
  q_dtype = getattr(torch, group['q_dtype'])
74
+ should_update = self.should_update(group)
74
75
 
75
76
  vals = []
76
77
 
@@ -111,7 +112,7 @@ class ForeachCachedPSGDKron(PSGDBase):
111
112
  q_orig = Q_list.pop(0)
112
113
  ea = exp_avg_list.pop(0)
113
114
 
114
- if self.should_update(group):
115
+ if should_update:
115
116
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
116
117
  q32 = [promote(q_) for q_ in q]
117
118
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
heavyball/delayed_psgd.py CHANGED
@@ -62,6 +62,7 @@ class ForeachDelayedPSGD(PSGDBase):
62
62
 
63
63
 
64
64
  def _step(self, group):
65
+ should_update = self.should_update(group)
65
66
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
66
67
  precond_init_scale = group['precond_init_scale']
67
68
  max_size_triangular = group['max_size_triangular']
@@ -103,7 +104,7 @@ class ForeachDelayedPSGD(PSGDBase):
103
104
  ea = exp_avg_list.pop(0)
104
105
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
105
106
  new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
106
- if self.should_update(group):
107
+ if should_update:
107
108
  q32 = [promote(q_) for q_ in q]
108
109
  self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig], store_triu_as_line)
109
110
  set_(g, new)
heavyball/p_adam.py CHANGED
@@ -61,6 +61,7 @@ class ForeachPaLMPAdam(PSGDBase):
61
61
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
62
62
 
63
63
  def _step(self, group):
64
+ should_update = self.should_update(group)
64
65
  precond_init_scale = group['precond_init_scale']
65
66
  max_size_triangular = group['max_size_triangular']
66
67
  min_ndim_triangular = group['min_ndim_triangular']
@@ -94,7 +95,7 @@ class ForeachPaLMPAdam(PSGDBase):
94
95
  group["step"] += 1
95
96
 
96
97
  Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
97
- if self.should_update(group):
98
+ if should_update:
98
99
  for g, p, q_, q_orig in zip(grad_list, p_list, Q_triu, Q_list):
99
100
  q32 = [promote(qq_) for qq_ in q_]
100
101
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
heavyball/psgd_kron.py CHANGED
@@ -60,6 +60,7 @@ class ForeachPSGDKron(PSGDBase):
60
60
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
61
61
 
62
62
  def _step(self, group):
63
+ should_update = self.should_update(group)
63
64
  momentum_into_precond_update = group.get("momentum_into_precond_update", True)
64
65
  precond_init_scale = group['precond_init_scale']
65
66
  max_size_triangular = group['max_size_triangular']
@@ -101,7 +102,7 @@ class ForeachPSGDKron(PSGDBase):
101
102
  ea = exp_avg_list.pop(0)
102
103
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
103
104
 
104
- if self.should_update(group):
105
+ if should_update:
105
106
  q32 = [promote(q_) for q_ in q]
106
107
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
107
108
  set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
heavyball/pure_psgd.py CHANGED
@@ -57,7 +57,7 @@ class ForeachPurePSGD(PSGDBase):
57
57
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
58
58
 
59
59
  def _step(self, group):
60
- # update preconditioners all together
60
+ should_update = self.should_update(group)
61
61
  precond_init_scale = group['precond_init_scale']
62
62
  max_size_triangular = group['max_size_triangular']
63
63
  min_ndim_triangular = group['min_ndim_triangular']
@@ -93,7 +93,7 @@ class ForeachPurePSGD(PSGDBase):
93
93
  q_orig = Q_list.pop(0)
94
94
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
95
95
 
96
- if self.should_update(group):
96
+ if group:
97
97
  q32 = [promote(q_) for q_ in q]
98
98
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
99
99
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
heavyball/utils.py CHANGED
@@ -335,9 +335,7 @@ def promote(x):
335
335
  def min_dtype(xs: List[torch.Tensor]):
336
336
  dtypes = [x.dtype for x in xs]
337
337
  for d in (torch.float32, torch.bfloat16, torch.float16):
338
- if all(d == x for x in dtypes):
339
- return d
340
- if all(d in (x, torch.float32, torch.float64) for x in dtypes):
338
+ if all(x in (d, torch.float32, torch.float64) for x in dtypes):
341
339
  return d
342
340
  return torch.float32
343
341
 
@@ -481,7 +479,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
481
479
  copy_stochastic_(t, s)
482
480
 
483
481
 
484
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
482
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
485
483
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
486
484
  """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
487
485
  # create a random 16 bit integer
@@ -817,20 +815,21 @@ class PSGDBase(StatefulOptimizer):
817
815
 
818
816
  def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
819
817
  store_triu_as_line=False):
820
- for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
818
+ if original_q:
819
+ if store_triu_as_line:
820
+ update_fn = update_triu_
821
+ else:
822
+ update_fn = copy_stochastic_list_
823
+ else:
824
+ update_fn = lambda x, y: None
825
+ for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
821
826
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
827
+ update_fn(oq, Q)
822
828
 
823
- for g, q in zip(grad_list, q_list):
829
+ for g, q in zip(grad_list, original_q if original_q else q_list):
824
830
  if g.dim() > 1:
825
831
  psgd_balance_Q(q)
826
832
 
827
- if original_q:
828
- for q in q_list:
829
- if store_triu_as_line:
830
- update_triu_(original_q[i], Q)
831
- else:
832
- copy_stochastic_list_(original_q[i], Q)
833
-
834
833
 
835
834
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
836
835
  """Anneal preconditioner update probability during beginning of training.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.2
3
+ Version: 0.18.4
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -1,24 +1,24 @@
1
1
  heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
2
- heavyball/cached_delayed_psgd_kron.py,sha256=t8XXsl91lINY0iB5Cn5aQDyjLxN2itGiA97Pur4mEkY,6422
3
- heavyball/cached_psgd_kron.py,sha256=NnzgJB11xPfi5NrHl3OsQkgs-fxeR_tsHdfXeDiXxbE,6379
4
- heavyball/delayed_psgd.py,sha256=ylLNHglvjnkYAmJwcl1TtPA4PXKPaOv1YHVt0JVabMA,5551
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
3
+ heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
4
+ heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
5
5
  heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
6
6
  heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
7
7
  heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
8
8
  heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
9
9
  heavyball/foreach_soap.py,sha256=h6ptMch7oaynvu3eIJtWnVXypDA_5JDVm3Zb3PNEma0,4634
10
- heavyball/p_adam.py,sha256=HnlOSR6fqOet0S4KavyU6zGtu7Tz7vUX18yJIgMnEBc,5845
10
+ heavyball/p_adam.py,sha256=4zJDGJrpgUyVzr3GiELETFre4xr3-PE10OuAZj-jFM8,5883
11
11
  heavyball/palm_foreach_sfadamw.py,sha256=yvZbPyjDW8qd3r4qDXb6uTr5RozQ7JSDj4aYYRnKGLA,2248
12
12
  heavyball/palm_foreach_soap.py,sha256=g4hbiGRcti-J-a0SwAkP4ii5pU-aalsZH5bssyhroLk,5938
13
13
  heavyball/precond_schedule_foreach_soap.py,sha256=WLg5SzpJnKPZUvFyIvdwSZa1Umt5cpr3Kow_42orM-E,4863
14
14
  heavyball/precond_schedule_palm_foreach_soap.py,sha256=ammQrvRZFF-wc-wEiPEoFhS_7b8pdV61QfcLoQfimSo,6211
15
15
  heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdojklrY8pA,6770
16
- heavyball/psgd_kron.py,sha256=KhZnV5MpigAEfJfvYI7ApF1GQ8ZWWXl7g5nYueWKYDQ,5438
17
- heavyball/pure_psgd.py,sha256=qPQ46pp7DWyQ1afBin2bqFVhaRhjt7RjXm6VuM_2sxg,4851
16
+ heavyball/psgd_kron.py,sha256=u46dorOUXx-do1IYeno2wj-6l1zYKMQQC-N2Zr2PzLI,5476
17
+ heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
19
- heavyball/utils.py,sha256=LTkqF5hP6z9De4_jDLF0HQYUwz1MkJEOTdMOYyH5D0k,30426
20
- heavyball-0.18.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.18.2.dist-info/METADATA,sha256=H074F1jIGrFvWdTdeRmhGXoxJ5V0zH7lIEhZ-LSP3Mc,11810
22
- heavyball-0.18.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.18.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.18.2.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=U8BX11BAhUCH_JO_tnVy4JLSkfDSBHFi_a5s8Pvsf-s,30437
20
+ heavyball-0.18.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.18.4.dist-info/METADATA,sha256=kBGWYUMlF21_Cg2qtcHa4ajfzB3xlm5UUb24S3mUgKI,11810
22
+ heavyball-0.18.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.18.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.18.4.dist-info/RECORD,,