heavyball: 0.25.1.tar.gz → 1.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.25.1 → heavyball-1.1.0}/PKG-INFO +6 -5
- {heavyball-0.25.1 → heavyball-1.1.0}/README.md +5 -4
- heavyball-1.1.0/heavyball/__init__.py +223 -0
- heavyball-1.1.0/heavyball/chainable.py +532 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/heavyball/utils.py +409 -231
- {heavyball-0.25.1 → heavyball-1.1.0}/heavyball.egg-info/PKG-INFO +6 -5
- heavyball-1.1.0/heavyball.egg-info/SOURCES.txt +27 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/setup.py +1 -1
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_bf16_params.py +7 -3
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_caution.py +7 -2
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_foreach.py +1 -1
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_psgd.py +7 -2
- heavyball-0.25.1/heavyball/__init__.py +0 -56
- heavyball-0.25.1/heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball-0.25.1/heavyball/cached_psgd_kron.py +0 -136
- heavyball-0.25.1/heavyball/delayed_psgd.py +0 -122
- heavyball-0.25.1/heavyball/foreach_adamw.py +0 -63
- heavyball-0.25.1/heavyball/foreach_adopt.py +0 -83
- heavyball-0.25.1/heavyball/foreach_laprop.py +0 -67
- heavyball-0.25.1/heavyball/foreach_sfadamw.py +0 -69
- heavyball-0.25.1/heavyball/foreach_soap.py +0 -93
- heavyball-0.25.1/heavyball/foreach_solp.py +0 -89
- heavyball-0.25.1/heavyball/p_adam.py +0 -121
- heavyball-0.25.1/heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball-0.25.1/heavyball/palm_foreach_soap.py +0 -101
- heavyball-0.25.1/heavyball/palm_foreach_solp.py +0 -98
- heavyball-0.25.1/heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball-0.25.1/heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball-0.25.1/heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball-0.25.1/heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball-0.25.1/heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball-0.25.1/heavyball/psgd_kron.py +0 -120
- heavyball-0.25.1/heavyball/pure_psgd.py +0 -105
- heavyball-0.25.1/heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1/heavyball.egg-info/SOURCES.txt +0 -48
- heavyball-0.25.1/test/test_solp.py +0 -50
- {heavyball-0.25.1 → heavyball-1.1.0}/LICENSE +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/setup.cfg +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_bf16_storage.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_channels_last.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_closure.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_ema.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_hook.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_mars.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_memory.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_merge.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_no_grad.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_soap.py +0 -0
- {heavyball-0.25.1 → heavyball-1.1.0}/test/test_stochastic_updates.py +0 -0
{heavyball-0.25.1 → heavyball-1.1.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.25.1
+Version: 1.1.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,11 +32,12 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-
+Currently (2024-12-07, 1.0.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
+* **Optax-like API**: `C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)`
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
 * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
@@ -130,6 +131,6 @@ To access `heavyball.utils`, you need to explicitly `import heavyball.utils`.\
 It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
-* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz
-  the eigenvectors.
+* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
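For orientation, the "Optax-like API" feature advertised above composes an optimizer out of small transforms from the new `heavyball.chainable` module. Below is a minimal sketch that simply mirrors the README's `grokfast` one-liner; it assumes `ChainOpt` accepts `(params, lr, *transforms)` and that `exp_avg` and `scale_by_adam` are exported as shown, none of which is verified against `chainable.py` in this diff.

```python
# Minimal sketch of the Optax-like chainable API, mirroring the README one-liner.
# Assumes ChainOpt(params, lr, *transforms); not verified against chainable.py.
import torch
import heavyball

C = heavyball.chainable
model = torch.nn.Linear(16, 16)

# Compose a grokfast-style optimizer: an EMA of the gradient fed into Adam-style scaling.
grokfast = C.ChainOpt(model.parameters(), 1e-3, C.exp_avg, C.scale_by_adam)
```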
{heavyball-0.25.1 → heavyball-1.1.0}/README.md

@@ -8,11 +8,12 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-
+Currently (2024-12-07, 1.0.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
 
+* **Optax-like API**: `C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)`
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
 * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
@@ -106,6 +107,6 @@ To access `heavyball.utils`, you need to explicitly `import heavyball.utils`.\
 It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
-* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz
-  the eigenvectors.
+* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
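The `heavyball.utils` knobs described in the hunks above are module-level settings. A hedged configuration sketch follows, using only the attribute names given in the README (`set_torch`, `compile_mode`, `zeroth_power_mode`); the exact strings accepted by `zeroth_power_mode` are not spelled out in this diff and are an assumption here.

```python
# Hedged configuration sketch for heavyball.utils, based on the README bullets above.
import heavyball.utils

heavyball.utils.set_torch()          # enable TF32, opt_einsum, cudnn benchmark, ...
heavyball.utils.compile_mode = None  # None disables torch.compile inside heavyball's compiled calls
# One of QR / newtonschulz / svd / eigh; the exact string spellings are assumed, not confirmed here.
heavyball.utils.zeroth_power_mode = 'qr'
```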
heavyball-1.1.0/heavyball/__init__.py (new file)

@@ -0,0 +1,223 @@
+import functools
+from typing import Optional
+
+from . import chainable as C
+from . import utils
+
+
+class ForeachAdamW(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+
+
+class ForeachRMSprop(C.BaseOpt):
+    """
+    Debiased RMSprop (not torch.optim.RMSprop)
+    """
+
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = dict(lr=lr, betas=betas, eps=eps, warmup_steps=warmup_steps, weight_decay=weight_decay,
+                        foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
+                        beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
+
+
+class ForeachSFAdamW(C.ScheduleFree):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
+                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
+                        weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
+                        foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
+                        beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
+                         C.update_by_schedule_free)
+
+
+class PaLMForeachSFAdamW(ForeachSFAdamW):
+    palm: bool = True
+
+
+class ForeachADOPT(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
+
+
+class ForeachMuon(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
+                 nesterov: bool = True):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+
+
+class ForeachLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
+
+
+class ForeachSOAP(C.BaseOpt):
+    """
+    ForeachSOAP
+
+    Sources:
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+
+        ScheduleFree:
+            The Road Less Scheduled
+            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
+            https://arxiv.org/abs/2405.15682
+            https://github.com/facebookresearch/schedule_free
+    """
+    use_precond_schedule: bool = False
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
+                 beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
+        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+        defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
+                    "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
+                    "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
+                    "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
+                    'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
+                    'beta2_scale': beta2_scale}
+        if use_precond_schedule:
+            del defaults['precondition_frequency']
+        else:
+            del defaults['precond_scheduler']
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
+                         C.scale_by_soap)
+
+
+class PaLMForeachSOAP(ForeachSOAP):
+    use_precond_schedule: bool = False
+    palm: bool = True
+
+
+class PrecondScheduleForeachSOAP(ForeachSOAP):
+    use_precond_schedule: bool = True
+
+
+class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
+    use_precond_schedule: bool = True
+    palm: bool = True
+
+
+class ForeachPSGDKron(C.BaseOpt):
+    """
+    Originally from Evan Walters and Omead Pooladzandi, 2024
+    Modified under Creative Commons Attribution 4.0 International
+    Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+    """
+
+    delayed: bool = False
+    cached: bool = False
+    exp_avg_input: bool = True
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
+                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
+                 cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
+                 # expert parameters
+                 precond_init_scale=1.0, precond_lr=0.1):
+        delayed = C.default(delayed, self.delayed)
+        cached = C.default(cached, self.cached)
+        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+        update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
+                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
+                        stochastic_schedule=stochastic_schedule)
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,  #
+                         *(C.exp_avg,) * exp_avg_input,  #
+                         functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
+                                           prob=preconditioner_update_probability))
+
+
+class ForeachPurePSGD(ForeachPSGDKron):
+    exp_avg_input: bool = False
+
+
+class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
+    delayed: bool = True
+    cached: bool = True
+
+
+class ForeachCachedPSGDKron(ForeachPSGDKron):
+    cached: bool = True
+
+
+class ForeachDelayedPSGD(ForeachPSGDKron):
+    delayed: bool = True
+
+
+PalmForEachSoap = PaLMForeachSOAP
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+RMSprop = ForeachRMSprop
+PrecondScheduleSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+Muon = ForeachMuon
+
+__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop',  #
+           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
+           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
+           "ForeachRMSprop", "ForeachMuon"]
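Every class in the new `__init__.py` ultimately builds on `C.BaseOpt`, so usage should follow the usual `torch.optim` pattern; the README also presents heavyball as a largely static alternative to `torch.optim`. A hedged usage sketch with the recommended `PrecondSchedulePaLMSOAP` alias exported above (the standard `step()`/`zero_grad()` loop is assumed rather than shown in this diff):

```python
# Hedged usage sketch: heavyball optimizers are assumed to be drop-in torch.optim-style objects.
import torch
import heavyball

model = torch.nn.Linear(32, 2)
opt = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=3e-3)  # recommended stable optimizer per the README

for _ in range(10):
    loss = model(torch.randn(8, 32)).square().mean()
    loss.backward()
    opt.step()       # standard torch.optim.Optimizer interface assumed
    opt.zero_grad()
```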