heavyball 0.24.2__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/cached_delayed_psgd_kron.py +1 -1
- heavyball/cached_psgd_kron.py +5 -2
- heavyball/delayed_psgd.py +1 -1
- heavyball/utils.py +5 -4
- {heavyball-0.24.2.dist-info → heavyball-0.24.3.dist-info}/METADATA +1 -1
- {heavyball-0.24.2.dist-info → heavyball-0.24.3.dist-info}/RECORD +9 -9
- {heavyball-0.24.2.dist-info → heavyball-0.24.3.dist-info}/LICENSE +0 -0
- {heavyball-0.24.2.dist-info → heavyball-0.24.3.dist-info}/WHEEL +0 -0
- {heavyball-0.24.2.dist-info → heavyball-0.24.3.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
22
22
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
23
23
|
parameter groups.
|
24
24
|
lr (float): Learning rate.
|
25
|
-
|
25
|
+
beta (float): Momentum parameter.
|
26
26
|
weight_decay (float): Weight decay (L2 penalty).
|
27
27
|
preconditioner_update_probability (callable or float, optional): Probability of
|
28
28
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
heavyball/cached_psgd_kron.py
CHANGED
@@ -19,7 +19,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
19
19
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
20
20
|
parameter groups.
|
21
21
|
lr (float): Learning rate.
|
22
|
-
|
22
|
+
beta (float): Momentum parameter.
|
23
23
|
weight_decay (float): Weight decay (L2 penalty).
|
24
24
|
preconditioner_update_probability (callable or float, optional): Probability of
|
25
25
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
@@ -41,6 +41,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
41
41
|
split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
|
42
42
|
foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
|
43
43
|
storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
|
44
|
+
orthogonalize_output: bool = False,
|
44
45
|
#
|
45
46
|
# expert parameters
|
46
47
|
precond_init_scale=1.0, precond_lr=0.1):
|
@@ -59,7 +60,8 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
59
60
|
momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
|
60
61
|
precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
|
61
62
|
split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
|
62
|
-
storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars
|
63
|
+
storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
|
64
|
+
orthogonalize_output=orthogonalize_output)
|
63
65
|
super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
|
64
66
|
|
65
67
|
def _step(self, group):
|
@@ -75,6 +77,7 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
75
77
|
store_triu_as_line = group['store_triu_as_line']
|
76
78
|
q_dtype = getattr(torch, group['q_dtype'])
|
77
79
|
storage_dtype = getattr(torch, group['storage_dtype'])
|
80
|
+
orthogonalize_output = group['orthogonalize_output']
|
78
81
|
should_update = self.should_update(group)
|
79
82
|
|
80
83
|
vals = []
|
heavyball/delayed_psgd.py
CHANGED
@@ -25,7 +25,7 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
25
25
|
params (iterable): Iterable of parameters to optimize or dicts defining
|
26
26
|
parameter groups.
|
27
27
|
lr (float): Learning rate.
|
28
|
-
|
28
|
+
beta (float): Momentum parameter.
|
29
29
|
weight_decay (float): Weight decay (L2 penalty).
|
30
30
|
preconditioner_update_probability (callable or float, optional): Probability of
|
31
31
|
updating the preconditioner. If None, defaults to a schedule that anneals
|
heavyball/utils.py
CHANGED
@@ -23,7 +23,8 @@ def decorator(func):
|
|
23
23
|
|
24
24
|
@functools.wraps(func)
|
25
25
|
def _fn(*args, **kwargs):
|
26
|
-
|
26
|
+
disable = compile_mode_recommended_to_none is None
|
27
|
+
if is_compiling() or compile_mode_recommended_to_none is None:
|
27
28
|
return func(*args, **kwargs)
|
28
29
|
nonlocal compiled
|
29
30
|
if compiled is None:
|
@@ -874,7 +875,7 @@ def psgd_lb(A, max_abs):
|
|
874
875
|
return x
|
875
876
|
|
876
877
|
|
877
|
-
@
|
878
|
+
@decorator
|
878
879
|
def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
|
879
880
|
"""Update Kronecker product preconditioner Q with pair (V, G)."""
|
880
881
|
exprA, exprGs, _ = exprs
|
@@ -1130,11 +1131,11 @@ def merge_group(group, *tensors):
|
|
1130
1131
|
'max_precond_dim'], group.get('split', False)))
|
1131
1132
|
return out
|
1132
1133
|
|
1134
|
+
|
1133
1135
|
def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
|
1134
1136
|
def _step(p: Tensor, o: torch.optim.Optimizer):
|
1135
1137
|
o.step()
|
1136
1138
|
o.zero_grad()
|
1137
1139
|
|
1138
|
-
|
1139
1140
|
for p in model.parameters():
|
1140
|
-
p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
|
1141
|
+
p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
|
@@ -1,7 +1,7 @@
|
|
1
1
|
heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
|
2
|
-
heavyball/cached_delayed_psgd_kron.py,sha256=
|
3
|
-
heavyball/cached_psgd_kron.py,sha256=
|
4
|
-
heavyball/delayed_psgd.py,sha256=
|
2
|
+
heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
|
3
|
+
heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
|
4
|
+
heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
|
5
5
|
heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,2860
|
6
6
|
heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
|
7
7
|
heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
|
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_V
|
|
16
16
|
heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
|
17
17
|
heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
|
18
18
|
heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
|
19
|
-
heavyball/utils.py,sha256=
|
20
|
-
heavyball-0.24.
|
21
|
-
heavyball-0.24.
|
22
|
-
heavyball-0.24.
|
23
|
-
heavyball-0.24.
|
24
|
-
heavyball-0.24.
|
19
|
+
heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
|
20
|
+
heavyball-0.24.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
21
|
+
heavyball-0.24.3.dist-info/METADATA,sha256=32T-Q-a4k096KjxoR-3DQt25XpO_h0zs7lWKTDQLugI,11926
|
22
|
+
heavyball-0.24.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
23
|
+
heavyball-0.24.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
24
|
+
heavyball-0.24.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|