heavyball 0.18.3__tar.gz → 0.18.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.18.3 → heavyball-0.18.4}/PKG-INFO +1 -1
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/utils.py +11 -10
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.18.3 → heavyball-0.18.4}/setup.py +1 -1
- {heavyball-0.18.3 → heavyball-0.18.4}/LICENSE +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/README.md +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/__init__.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/p_adam.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/setup.cfg +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_bf16_q.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_closure.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_foreach.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_memory.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_merge.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_no_grad.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_psgd.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_soap.py +0 -0
- {heavyball-0.18.3 → heavyball-0.18.4}/test/test_stochastic_updates.py +0 -0
@@ -479,7 +479,7 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
|
|
479
479
|
copy_stochastic_(t, s)
|
480
480
|
|
481
481
|
|
482
|
-
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=
|
482
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
483
483
|
def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
484
484
|
"""Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
|
485
485
|
# create a random 16 bit integer
|
@@ -815,20 +815,21 @@ class PSGDBase(StatefulOptimizer):
|
|
815
815
|
|
816
816
|
def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
|
817
817
|
store_triu_as_line=False):
|
818
|
-
|
818
|
+
if original_q:
|
819
|
+
if store_triu_as_line:
|
820
|
+
update_fn = update_triu_
|
821
|
+
else:
|
822
|
+
update_fn = copy_stochastic_list_
|
823
|
+
else:
|
824
|
+
update_fn = lambda x, y: None
|
825
|
+
for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
|
819
826
|
psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
|
827
|
+
update_fn(oq, Q)
|
820
828
|
|
821
|
-
for g, q in zip(grad_list, q_list):
|
829
|
+
for g, q in zip(grad_list, original_q if original_q else q_list):
|
822
830
|
if g.dim() > 1:
|
823
831
|
psgd_balance_Q(q)
|
824
832
|
|
825
|
-
if original_q:
|
826
|
-
for q in q_list:
|
827
|
-
if store_triu_as_line:
|
828
|
-
update_triu_(original_q[i], Q)
|
829
|
-
else:
|
830
|
-
copy_stochastic_list_(original_q[i], Q)
|
831
|
-
|
832
833
|
|
833
834
|
def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
|
834
835
|
"""Anneal preconditioner update probability during beginning of training.
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|