heavyball 0.24.3.tar.gz → 0.25.0.tar.gz
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {heavyball-0.24.3 → heavyball-0.25.0}/PKG-INFO +1 -1
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/foreach_soap.py +2 -3
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/palm_foreach_soap.py +2 -3
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/precond_schedule_foreach_soap.py +2 -2
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/precond_schedule_palm_foreach_soap.py +2 -2
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/utils.py +10 -7
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.24.3 → heavyball-0.25.0}/setup.py +1 -1
- {heavyball-0.24.3 → heavyball-0.25.0}/LICENSE +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/README.md +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/__init__.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/delayed_psgd.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/p_adam.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/setup.cfg +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_bf16_params.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_bf16_storage.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_caution.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_channels_last.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_closure.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_ema.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_foreach.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_hook.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_mars.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_memory.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_merge.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_no_grad.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_psgd.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_soap.py +0 -0
- {heavyball-0.24.3 → heavyball-0.25.0}/test/test_stochastic_updates.py +0 -0
```diff
--- a/heavyball/foreach_soap.py
+++ b/heavyball/foreach_soap.py
@@ -71,12 +71,11 @@ class ForeachSOAP(StatefulOptimizer):
         # Decay the first and second moment running average coefficient
         # In-place operations to update the averages at the same time
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
-        for p, g, ea,
-
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
             exp_avg_projected = project(ea, state['Q'], False)
```
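All four SOAP variants receive the same refactor: instead of calling `exp_avg_` once on the full parameter list and materializing the denominators up front, the moment update now runs inside the per-parameter loop, one tensor at a time, which keeps fewer intermediates alive at once. A minimal sketch of the pattern follows; the `_exp_avg` helper, its Adam-style update, and all names here are illustrative stand-ins, not heavyball's actual implementation.

```python
import torch

def _exp_avg(ea, eas, g, beta1, beta2, eps=1e-8):
    # Illustrative Adam-style moment update: returns the denominator
    # for this single tensor instead of a list for every parameter.
    ea.mul_(beta1).add_(g, alpha=1 - beta1)
    eas.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    return eas.sqrt().add_(eps)

params = [torch.randn(4, 4) for _ in range(3)]
grads = [torch.randn_like(p) for p in params]
exp_avg = [torch.zeros_like(p) for p in params]
exp_avg_sq = [torch.zeros_like(p) for p in params]

# 0.24.3 style (sketch): one list-wise call producing all denominators, then a loop.
# 0.25.0 style (sketch): compute the denominator per parameter inside the loop.
for p, g, ea, eas in zip(params, grads, exp_avg, exp_avg_sq):
    d = _exp_avg(ea, eas, g, beta1=0.9, beta2=0.999)
    p.add_(ea / d, alpha=-1e-3)  # toy update; the real code projects first
```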
```diff
--- a/heavyball/palm_foreach_soap.py
+++ b/heavyball/palm_foreach_soap.py
@@ -81,11 +81,10 @@ class PaLMForeachSOAP(StatefulOptimizer):
         # In-place operations to update the averages at the same time
         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
-        for p, g, ea,
-
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
             exp_avg_projected = project(ea, state['Q'], False)
```
```diff
--- a/heavyball/precond_schedule_foreach_soap.py
+++ b/heavyball/precond_schedule_foreach_soap.py
@@ -73,12 +73,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
         # Decay the first and second moment running average coefficient
         # In-place operations to update the averages at the same time
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
-        for p, g, ea,
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
```
```diff
--- a/heavyball/precond_schedule_palm_foreach_soap.py
+++ b/heavyball/precond_schedule_palm_foreach_soap.py
@@ -84,12 +84,12 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
         # In-place operations to update the averages at the same time
         beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
         step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
-        denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
 
         update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
         step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
 
-        for p, g, ea,
+        for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
+            d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
             state = self.state_(p)
             # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
             # i.e. projecting to the eigenbases of matrices in state['GG']
```
```diff
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -492,22 +492,20 @@ class StatefulOptimizer(torch.optim.Optimizer):
         super().__init__(params, {**defaults, 'foreach': foreach})
         self.fake_groups = {}
         self.use_ema = use_ema
-
-    def key(self, param: Tensor):
-        return (param.data_ptr(), tuple(param.shape))
+        self.mapping = {}
 
     def get_groups(self, group):
         if group['foreach']:
             return [group]
 
         for p in group['params']:
-            if
-                self.fake_groups[
+            if p not in self.fake_groups:
+                self.fake_groups[p] = {**group, 'params': [p]}
 
-        return [self.fake_groups[
+        return [self.fake_groups[p] for p in group['params']]
 
     def state_(self, arg: Tensor):
-        return self.state[self.
+        return self.state[self.mapping.get(arg, arg)]
 
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
         for p, g in zip(p_list, g_list):
```
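This `StatefulOptimizer` change removes the `key()` helper: `fake_groups` (and, via the new `mapping`, optimizer state) is now keyed by the tensor object itself rather than by a `(data_ptr, shape)` tuple. That works because `torch.Tensor` hashes by object identity, so a parameter can serve directly as a dictionary key, just as PyTorch's built-in optimizers key `self.state` by parameter. A small sketch, independent of heavyball, illustrating that behaviour; the remaining utils.py hunks below populate and consume the new mapping.

```python
import torch

p = torch.nn.Parameter(torch.randn(2, 3))
groups = {}

# Tensors hash by identity, so the parameter object itself is a stable key
# for as long as the same object is passed back in.
if p not in groups:
    groups[p] = {'params': [p], 'lr': 1e-3}

assert groups[p]['params'][0] is p
```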
```diff
@@ -538,6 +536,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
+            for i, pv in enumerate(p_views):
+                self.mapping[pv] = (p, i)
             if isinstance(p_views, Tensor):
                 yield p_views, grad
                 continue
```
```diff
@@ -622,11 +622,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
         for top_group in self.param_groups:
             for group in self.get_groups(top_group):
                 self._step(group)
+                self.mapping.clear()
                 if self.use_ema:
                     self.ema_update(group)
+
         return loss
 
 
+
 class ScheduleFree(StatefulOptimizer):
     def eval(self):
         for group in self.param_groups:
```
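Taken together, the utils.py hunks give the merged parameter views a way back to their source: `self.mapping` records, for every view produced from a parameter, which original parameter and slot it came from; `state_` resolves a view via `mapping.get(arg, arg)`; and `step()` clears the mapping once every group has been processed. A rough sketch of that lifecycle, with hypothetical names (`split_views` and `TinyStateful` stand in for heavyball's real merge/view and optimizer machinery):

```python
import torch
from torch import Tensor

def split_views(p: Tensor):
    # Hypothetical stand-in for heavyball's merge_group/view logic:
    # pretend every parameter is split into two flat views.
    flat = p.view(-1)
    half = flat.numel() // 2
    return [flat[:half], flat[half:]]

class TinyStateful:
    def __init__(self, params):
        self.params = list(params)
        self.state = {}     # keyed by (original parameter, view index) or the tensor itself
        self.mapping = {}   # view tensor -> (original parameter, view index)

    def state_(self, arg: Tensor):
        key = self.mapping.get(arg, arg)
        return self.state.setdefault(key, {})

    def step(self):
        for p in self.params:
            views = split_views(p)
            for i, pv in enumerate(views):
                self.mapping[pv] = (p, i)
            for pv in views:
                st = self.state_(pv)
                st['steps'] = st.get('steps', 0) + 1
        self.mapping.clear()  # views are rebuilt each step, so drop the stale keys

opt = TinyStateful([torch.nn.Parameter(torch.randn(4, 4))])
opt.step()
```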