heavyball 0.24.4__tar.gz → 0.25.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.24.4 → heavyball-0.25.0}/PKG-INFO +1 -1
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/utils.py +10 -7
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.24.4 → heavyball-0.25.0}/setup.py +1 -1
- {heavyball-0.24.4 → heavyball-0.25.0}/LICENSE +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/README.md +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/__init__.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/p_adam.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/setup.cfg +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_params.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_bf16_storage.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_caution.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_channels_last.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_closure.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_ema.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_foreach.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_hook.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_mars.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_memory.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_merge.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_no_grad.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_psgd.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_soap.py +0 -0
- {heavyball-0.24.4 → heavyball-0.25.0}/test/test_stochastic_updates.py +0 -0
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
492
492
|
super().__init__(params, {**defaults, 'foreach': foreach})
|
493
493
|
self.fake_groups = {}
|
494
494
|
self.use_ema = use_ema
|
495
|
-
|
496
|
-
def key(self, param: Tensor):
|
497
|
-
return (param.data_ptr(), tuple(param.shape))
|
495
|
+
self.mapping = {}
|
498
496
|
|
499
497
|
def get_groups(self, group):
|
500
498
|
if group['foreach']:
|
501
499
|
return [group]
|
502
500
|
|
503
501
|
for p in group['params']:
|
504
|
-
if
|
505
|
-
self.fake_groups[
|
502
|
+
if p not in self.fake_groups:
|
503
|
+
self.fake_groups[p] = {**group, 'params': [p]}
|
506
504
|
|
507
|
-
return [self.fake_groups[
|
505
|
+
return [self.fake_groups[p] for p in group['params']]
|
508
506
|
|
509
507
|
def state_(self, arg: Tensor):
|
510
|
-
return self.state[self.
|
508
|
+
return self.state[self.mapping.get(arg, arg)]
|
511
509
|
|
512
510
|
def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
|
513
511
|
for p, g in zip(p_list, g_list):
|
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
538
536
|
p_views = merge_group(group, p)
|
539
537
|
if grad is not None:
|
540
538
|
grad = merge_group(group, grad)
|
539
|
+
for i, pv in enumerate(p_views):
|
540
|
+
self.mapping[pv] = (p, i)
|
541
541
|
if isinstance(p_views, Tensor):
|
542
542
|
yield p_views, grad
|
543
543
|
continue
|
@@ -622,11 +622,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
|
|
622
622
|
for top_group in self.param_groups:
|
623
623
|
for group in self.get_groups(top_group):
|
624
624
|
self._step(group)
|
625
|
+
self.mapping.clear()
|
625
626
|
if self.use_ema:
|
626
627
|
self.ema_update(group)
|
628
|
+
|
627
629
|
return loss
|
628
630
|
|
629
631
|
|
632
|
+
|
630
633
|
class ScheduleFree(StatefulOptimizer):
|
631
634
|
def eval(self):
|
632
635
|
for group in self.param_groups:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|