heavyball 0.24.2__py3-none-any.whl → 0.24.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +1 -1
- heavyball/cached_psgd_kron.py +5 -2
- heavyball/delayed_psgd.py +1 -1
- heavyball/foreach_soap.py +2 -3
- heavyball/palm_foreach_soap.py +2 -3
- heavyball/precond_schedule_foreach_soap.py +2 -2
- heavyball/precond_schedule_palm_foreach_soap.py +2 -2
- heavyball/utils.py +5 -4
- {heavyball-0.24.2.dist-info → heavyball-0.24.4.dist-info}/METADATA +1 -1
- heavyball-0.24.4.dist-info/RECORD +24 -0
- heavyball-0.24.2.dist-info/RECORD +0 -24
- {heavyball-0.24.2.dist-info → heavyball-0.24.4.dist-info}/LICENSE +0 -0
- {heavyball-0.24.2.dist-info → heavyball-0.24.4.dist-info}/WHEEL +0 -0
- {heavyball-0.24.2.dist-info → heavyball-0.24.4.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
22
22
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
23
23
|
parameter groups.
|
24
24
|
lr (float): Learning rate.
|
25
|
-
|
25
|
+
beta (float): Momentum parameter.
|
26
26
|
weight_decay (float): Weight decay (L2 penalty).
|
27
27
|
preconditioner_update_probability (callable or float, optional): Probability of
|
28
28
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
heavyball/cached_psgd_kron.py
CHANGED
@@ -19,7 +19,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
19
19
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
20
20
|
parameter groups.
|
21
21
|
lr (float): Learning rate.
|
22
|
-
|
22
|
+
beta (float): Momentum parameter.
|
23
23
|
weight_decay (float): Weight decay (L2 penalty).
|
24
24
|
preconditioner_update_probability (callable or float, optional): Probability of
|
25
25
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
@@ -41,6 +41,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
41
41
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
42
42
|
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
|
43
43
|
storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
|
44
|
+
orthogonalize_output: bool = False,
|
44
45
|
#
|
45
46
|
# expert parameters
|
46
47
|
precond_init_scale=1.0, precond_lr=0.1):
|
@@ -59,7 +60,8 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
59
60
|
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
60
61
|
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
61
62
|
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
|
62
|
-
storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars
|
63
|
+
storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
|
64
|
+
orthogonalize_output=orthogonalize_output)
|
63
65
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
64
66
|
|
65
67
|
def _step(self, group):
|
@@ -75,6 +77,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
75
77
|
store_triu_as_line = group['store_triu_as_line']
|
76
78
|
q_dtype = getattr(torch, group['q_dtype'])
|
77
79
|
storage_dtype = getattr(torch, group['storage_dtype'])
|
80
|
+
orthogonalize_output = group['orthogonalize_output']
|
78
81
|
should_update = self.should_update(group)
|
79
82
|
|
80
83
|
vals = []
|
heavyball/delayed_psgd.py
CHANGED
@@ -25,7 +25,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
25
25
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
26
26
|
parameter groups.
|
27
27
|
lr (float): Learning rate.
|
28
|
-
|
28
|
+
beta (float): Momentum parameter.
|
29
29
|
weight_decay (float): Weight decay (L2 penalty).
|
30
30
|
preconditioner_update_probability (callable or float, optional): Probability of
|
31
31
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
heavyball/foreach_soap.py
CHANGED
@@ -71,12 +71,11 @@ class ForeachSOAP(StatefulOptimizer):
|
|
71
71
|
# Decay the first and second moment running average coefficient
|
72
72
|
# In-place operations to update the averages at the same time
|
73
73
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
74
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
75
74
|
|
76
75
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
77
76
|
|
78
|
-
for p, g, ea,
|
79
|
-
|
77
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
78
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
80
79
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
81
80
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
82
81
|
exp_avg_projected = project(ea, state['Q'], False)
|
heavyball/palm_foreach_soap.py
CHANGED
@@ -81,11 +81,10 @@ class PaLMForeachSOAP(StatefulOptimizer):
|
|
81
81
|
# In-place operations to update the averages at the same time
|
82
82
|
beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
|
83
83
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
84
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
85
84
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
86
85
|
|
87
|
-
for p, g, ea,
|
88
|
-
|
86
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
87
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
89
88
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
90
89
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
91
90
|
exp_avg_projected = project(ea, state['Q'], False)
|
@@ -73,12 +73,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
|
|
73
73
|
# Decay the first and second moment running average coefficient
|
74
74
|
# In-place operations to update the averages at the same time
|
75
75
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
76
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
77
76
|
|
78
77
|
update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
|
79
78
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
80
79
|
|
81
|
-
for p, g, ea,
|
80
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
81
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
82
82
|
state = self.state_(p)
|
83
83
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
84
84
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
@@ -84,12 +84,12 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):
|
|
84
84
|
# In-place operations to update the averages at the same time
|
85
85
|
beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
|
86
86
|
step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
|
87
|
-
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)
|
88
87
|
|
89
88
|
update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
|
90
89
|
step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
|
91
90
|
|
92
|
-
for p, g, ea,
|
91
|
+
for p, g, gp, ea, eas in zip(p_list, grad, grad_projected, exp_avg, exp_avg_sq):
|
92
|
+
d = exp_avg_(ea, eas, g, gp, beta1, beta2, step_tensor)[0]
|
93
93
|
state = self.state_(p)
|
94
94
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
95
95
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
heavyball/utils.py
CHANGED
@@ -23,7 +23,8 @@ def decorator(func):
|
|
23
23
|
|
24
24
|
@functools.wraps(func)
|
25
25
|
def _fn(*args, **kwargs):
|
26
|
-
|
26
|
+
disable = compile_mode_recommended_to_none is None
|
27
|
+
if is_compiling() or compile_mode_recommended_to_none is None:
|
27
28
|
return func(*args, **kwargs)
|
28
29
|
nonlocal compiled
|
29
30
|
if compiled is None:
|
@@ -874,7 +875,7 @@ def psgd_lb(A, max_abs):
|
|
874
875
|
return x
|
875
876
|
|
876
877
|
|
877
|
-
@
|
878
|
+
@decorator
|
878
879
|
def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
|
879
880
|
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
880
881
|
exprA, exprGs, _ = exprs
|
@@ -1130,11 +1131,11 @@ def merge_group(group, *tensors):
|
|
1130
1131
|
'max_precond_dim'], group.get('split', False)))
|
1131
1132
|
return out
|
1132
1133
|
|
1134
|
+
|
1133
1135
|
def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
|
1134
1136
|
def _step(p: Tensor, o: torch.optim.Optimizer):
|
1135
1137
|
o.step()
|
1136
1138
|
o.zero_grad()
|
1137
1139
|
|
1138
|
-
|
1139
1140
|
for p in model.parameters():
|
1140
|
-
p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
|
1141
|
+
p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
|
@@ -0,0 +1,24 @@
|
|
1
|
+
heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
|
2
|
+
heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
|
3
|
+
heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
|
4
|
+
heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
|
5
|
+
heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,2860
|
6
|
+
heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
|
7
|
+
heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
|
8
|
+
heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
|
9
|
+
heavyball/foreach_soap.py,sha256=Tgwg4_Sir9nI_3R85f8NMQagquUBJmAEMQqh0uD3b0Y,4771
|
10
|
+
heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
|
11
|
+
heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
|
12
|
+
heavyball/palm_foreach_soap.py,sha256=zSjpYYm1hfgIudjo_q3ozu3Vkfhz8w8im1c-ou1U3sI,6198
|
13
|
+
heavyball/precond_schedule_foreach_soap.py,sha256=p7oD2bESyCPsdGkJYhHluraDb_1K5Q28RNL6fIvD5C8,4969
|
14
|
+
heavyball/precond_schedule_palm_foreach_soap.py,sha256=Sb3Fhv-EG28_oXnbVpE0iHe5R8i5_hltqoi_DgPuoEU,6505
|
15
|
+
heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
|
16
|
+
heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
|
17
|
+
heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
|
18
|
+
heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
|
19
|
+
heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
|
20
|
+
heavyball-0.24.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.24.4.dist-info/METADATA,sha256=oksy8cvcHSdMEs9Mzv7WDAVkkkUPcpA1uYNlUgZM_bk,11926
|
22
|
+
heavyball-0.24.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.24.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.24.4.dist-info/RECORD,,
|
@@ -1,24 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
|
2
|
-
heavyball/cached_delayed_psgd_kron.py,sha256=cHwVDq-_284_eMt09rAq26D_8fv3N0e0wdN1woCHU1M,6864
|
3
|
-
heavyball/cached_psgd_kron.py,sha256=ttg6bemNDRpCJBV3aJg2DSyVfsfTMZAnhErgwC2jXlw,6815
|
4
|
-
heavyball/delayed_psgd.py,sha256=yHy83YQ_PKWtwQq1R_OVyj3cjmcbsZAXX1M-hGyciss,6332
|
5
|
-
heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,2860
|
6
|
-
heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
|
7
|
-
heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
|
8
|
-
heavyball/foreach_sfadamw.py,sha256=TeWf0nKXQEFcz02rADYRJenDM9mX1dGHhvILLks6OW8,3087
|
9
|
-
heavyball/foreach_soap.py,sha256=408jRysE9ek0ea-TphhSBMTa9zcjkgMX3qlx8qTCt34,4803
|
10
|
-
heavyball/p_adam.py,sha256=qEcuU8VEc35vaWAXjT0O65vfCuNn_3ttwL4RlJKN3Xw,6389
|
11
|
-
heavyball/palm_foreach_sfadamw.py,sha256=1qOr-uniSmI1sNCJc1SnvyKH5iFu80Z6H5h93lDTwcE,3410
|
12
|
-
heavyball/palm_foreach_soap.py,sha256=cExM9nTC3zAgsRr42VOIMWNwYA4dAJaA8-pIo7SWilc,6230
|
13
|
-
heavyball/precond_schedule_foreach_soap.py,sha256=EL_Z-v5l7BC98QgI-Zg9iyM77TAreVgD5Zln59ewGoI,4966
|
14
|
-
heavyball/precond_schedule_palm_foreach_soap.py,sha256=HWo2t7yY-_n4pPGmDiELccy0jdELTVhdlH-eyFBih5k,6502
|
15
|
-
heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_VuzASluIo,7565
|
16
|
-
heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
|
17
|
-
heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
|
18
|
-
heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
|
19
|
-
heavyball/utils.py,sha256=FglgQfiE206I07rql3qP-X2C1j0hY3N5VcQwKUh08aA,40025
|
20
|
-
heavyball-0.24.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
-
heavyball-0.24.2.dist-info/METADATA,sha256=lTThJQbW6qbnQqy9lGlTTOttJcX5vfQ_s6Cm0arqfC8,11926
|
22
|
-
heavyball-0.24.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
-
heavyball-0.24.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
-
heavyball-0.24.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|