heavyball 0.18.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -113,6 +113,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
113
113
  q_orig = Q_list.pop(0)
114
114
  ea = exp_avg_list.pop(0)
115
115
 
116
+ new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
117
+
116
118
  if self.should_update(group):
117
119
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
118
120
  q32 = [promote(q_) for q_ in q]
@@ -124,7 +126,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
124
126
  else:
125
127
  torch.mul(q_.conj(), q_, out=c_)
126
128
 
127
- set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
129
+ set_(g, new)
128
130
  grad_list = self.clip_fn(grad_list)
129
131
 
130
132
  lr = -warmup(lr, group['step'], group['warmup_steps'])
@@ -111,8 +111,6 @@ class ForeachCachedPSGDKron(PSGDBase):
111
111
  q_orig = Q_list.pop(0)
112
112
  ea = exp_avg_list.pop(0)
113
113
 
114
- new = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
115
-
116
114
  if self.should_update(group):
117
115
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
118
116
  q32 = [promote(q_) for q_ in q]
@@ -124,7 +122,7 @@ class ForeachCachedPSGDKron(PSGDBase):
124
122
  else:
125
123
  torch.mul(q_.conj(), q_, out=c_)
126
124
 
127
- set_(g, new)
125
+ set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
128
126
 
129
127
  grad_list = self.clip_fn(grad_list)
130
128
 
heavyball/utils.py CHANGED
@@ -817,14 +817,15 @@ class PSGDBase(StatefulOptimizer):
817
817
 
818
818
  def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
819
819
  store_triu_as_line=False):
820
- if self.should_update(group, self.balance_probability, 'balance_prob'):
821
- for g, q in zip(grad_list, q_list):
822
- if g.dim() > 1:
823
- psgd_balance_Q(q)
824
-
825
820
  for i, (p, grad, Q) in enumerate(zip(p_list, grad_list, q_list)):
826
821
  psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
827
- if original_q:
822
+
823
+ for g, q in zip(grad_list, q_list):
824
+ if g.dim() > 1:
825
+ psgd_balance_Q(q)
826
+
827
+ if original_q:
828
+ for q in q_list:
828
829
  if store_triu_as_line:
829
830
  update_triu_(original_q[i], Q)
830
831
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.18.1
3
+ Version: 0.18.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -1,6 +1,6 @@
1
1
  heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
2
- heavyball/cached_delayed_psgd_kron.py,sha256=IXGaMBqHidw7FikWjpc1yiuCLZrbCLi45rRvL2OfKxU,6399
3
- heavyball/cached_psgd_kron.py,sha256=rxrVaXYoivdAME6R8xoWECsIZ1DL8mfAw0YSu94pCj4,6402
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=t8XXsl91lINY0iB5Cn5aQDyjLxN2itGiA97Pur4mEkY,6422
3
+ heavyball/cached_psgd_kron.py,sha256=NnzgJB11xPfi5NrHl3OsQkgs-fxeR_tsHdfXeDiXxbE,6379
4
4
  heavyball/delayed_psgd.py,sha256=ylLNHglvjnkYAmJwcl1TtPA4PXKPaOv1YHVt0JVabMA,5551
5
5
  heavyball/foreach_adamw.py,sha256=h_ar0ZZRM0q_wxuEkxEOEYe-2p-mB4OMgAHivrUnPl8,1777
6
6
  heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
16
16
  heavyball/psgd_kron.py,sha256=KhZnV5MpigAEfJfvYI7ApF1GQ8ZWWXl7g5nYueWKYDQ,5438
17
17
  heavyball/pure_psgd.py,sha256=qPQ46pp7DWyQ1afBin2bqFVhaRhjt7RjXm6VuM_2sxg,4851
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
19
- heavyball/utils.py,sha256=sojDNo94l-jAPDSnhEM5EI6D83SHG_hTnJOkdnDquSI,30492
20
- heavyball-0.18.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.18.1.dist-info/METADATA,sha256=PeclJ6CUn0l-ycvQoW4uN-BTUml3xXdiPitBMidjryA,11810
22
- heavyball-0.18.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.18.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.18.1.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=LTkqF5hP6z9De4_jDLF0HQYUwz1MkJEOTdMOYyH5D0k,30426
20
+ heavyball-0.18.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.18.2.dist-info/METADATA,sha256=H074F1jIGrFvWdTdeRmhGXoxJ5V0zH7lIEhZ-LSP3Mc,11810
22
+ heavyball-0.18.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.18.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.18.2.dist-info/RECORD,,