heavyball 0.25.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +207 -40
- heavyball/chainable.py +532 -0
- heavyball/utils.py +409 -231
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/METADATA +6 -5
- heavyball-1.1.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.1.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,56 +1,223 @@
|
|
1
|
-
|
2
|
-
from
|
3
|
-
|
4
|
-
from .
|
5
|
-
from .
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
1
|
+
import functools
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from . import chainable as C
|
5
|
+
from . import utils
|
6
|
+
|
7
|
+
|
8
|
+
class ForeachAdamW(C.BaseOpt):
    """AdamW assembled from heavyball's chainable transforms.

    Every hyper-parameter is stored in the param-group defaults. The update itself is delegated to
    ``C.update_by_adam`` through ``C.BaseOpt``. ``gradient_clipping``, ``update_clipping`` and ``palm``
    are forwarded as given, so a ``C.use_default`` sentinel is presumably resolved by ``C.BaseOpt``.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # NOTE(review): k / train_mode / weight_sum / lr_max are the same bookkeeping keys ForeachSFAdamW seeds;
        # presumably unused by the plain Adam update -- confirm before dropping them.
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "k": 0,
            "warmup_steps": warmup_steps,
            "train_mode": True,
            "weight_sum": 0.0,
            "lr_max": -1.0,
            "weight_decay": weight_decay,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        update_fn = C.update_by_adam
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, update_fn)
|
17
|
+
|
18
|
+
|
19
|
+
class ForeachRMSprop(C.BaseOpt):
    """
    Debiased RMSprop (not torch.optim.RMSprop)

    The only transform handed to ``C.BaseOpt`` is ``C.scale_by_exp_avg_sq``.

    Note: ``r`` and ``weight_lr_power`` are accepted but never read by this constructor. They appear to
    mirror ``ForeachSFAdamW``'s signature.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "warmup_steps": warmup_steps,
            "weight_decay": weight_decay,
            "foreach": foreach,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        transform = C.scale_by_exp_avg_sq
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, transform)
|
32
|
+
|
33
|
+
|
34
|
+
class ForeachSFAdamW(C.ScheduleFree):
    """Schedule-free AdamW built on ``C.ScheduleFree``.

    Two transforms are chained, in order: ``C.scale_by_exp_avg_sq`` then ``C.update_by_schedule_free``.
    The ``k``, ``train_mode``, ``weight_sum`` and ``lr_max`` group entries are initial values for
    schedule-free bookkeeping. They are presumably advanced by ``C.update_by_schedule_free``, which is not
    visible here.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "r": r,
            "k": 0,
            "warmup_steps": warmup_steps,
            "train_mode": True,
            "weight_sum": 0.0,
            "lr_max": -1.0,
            "weight_lr_power": weight_lr_power,
            "weight_decay": weight_decay,
            "foreach": foreach,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        second_moment = C.scale_by_exp_avg_sq
        sf_update = C.update_by_schedule_free
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, second_moment,
                         sf_update)
|
45
|
+
|
46
|
+
|
47
|
+
class PaLMForeachSFAdamW(ForeachSFAdamW):
    """ForeachSFAdamW with the class-level ``palm`` flag switched on.

    The flag is presumably consulted by the base optimizer when the ``palm`` argument is left at
    ``C.use_default`` (not visible here).
    """

    palm: bool = True
|
49
|
+
|
50
|
+
|
51
|
+
class ForeachADOPT(C.BaseOpt):
    """ADOPT optimizer assembled from heavyball's chainable transforms.

    The constructor mirrors ``ForeachAdamW``; the only difference is the update rule, ``C.update_by_adopt``.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # NOTE(review): k / train_mode / weight_sum / lr_max look like schedule-free bookkeeping carried over
        # from ForeachSFAdamW; presumably unused here -- confirm.
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "k": 0,
            "warmup_steps": warmup_steps,
            "train_mode": True,
            "weight_sum": 0.0,
            "lr_max": -1.0,
            "weight_decay": weight_decay,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        update_fn = C.update_by_adopt
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, update_fn)
|
60
|
+
|
61
|
+
|
62
|
+
class ForeachMuon(C.BaseOpt):
    """Muon-style optimizer: a momentum transform followed by ``C.orthogonalize_update``.

    ``nesterov=True`` (the default) selects ``C.nesterov_momentum``; otherwise ``C.heavyball_momentum``
    is used.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
                 nesterov: bool = True):
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "k": 0,
            "warmup_steps": warmup_steps,
            "train_mode": True,
            "weight_sum": 0.0,
            "lr_max": -1.0,
            "weight_decay": weight_decay,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        if nesterov:
            momentum_fn = C.nesterov_momentum
        else:
            momentum_fn = C.heavyball_momentum
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, momentum_fn,
                         C.orthogonalize_update)
|
73
|
+
|
74
|
+
|
75
|
+
class ForeachLaProp(C.BaseOpt):
    """LaProp optimizer assembled from heavyball's chainable transforms.

    The constructor mirrors ``ForeachAdamW``; the only difference is the update rule, ``C.update_by_laprop``.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # NOTE(review): k / train_mode / weight_sum / lr_max look like schedule-free bookkeeping carried over
        # from ForeachSFAdamW; presumably unused here -- confirm.
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "k": 0,
            "warmup_steps": warmup_steps,
            "train_mode": True,
            "weight_sum": 0.0,
            "lr_max": -1.0,
            "weight_decay": weight_decay,
            "storage_dtype": storage_dtype,
            "mars": mars,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "beta2_scale": beta2_scale,
        }
        update_fn = C.update_by_laprop
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, update_fn)
|
84
|
+
|
85
|
+
|
86
|
+
class ForeachSOAP(C.BaseOpt):
    """
    ForeachSOAP

    Sources:
        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        ScheduleFree:
            The Road Less Scheduled
            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
            https://arxiv.org/abs/2405.15682
            https://github.com/facebookresearch/schedule_free

    The single transform handed to ``C.BaseOpt`` is ``C.scale_by_soap``.

    Preconditioner refresh is configured in one of two ways, and only the active one is stored in the
    param group:
      * ``use_precond_schedule`` falsy: fixed ``precondition_frequency``;
      * ``use_precond_schedule`` truthy: ``precond_scheduler``.
    Leaving ``use_precond_schedule`` at ``C.use_default`` falls back to the class attribute of the same
    name, which is how the PrecondSchedule* subclasses switch modes.

    Note: ``data_format`` is accepted but not used by this constructor.
    """

    use_precond_schedule: bool = False

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                 beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

        group_defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
            'warmup_steps': warmup_steps,
            'split': split,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'palm': palm,
            'precond_scheduler': precond_scheduler,
            'beta2_scale': beta2_scale,
        }
        # Keep exactly one refresh setting: a schedule replaces the fixed frequency, and vice versa.
        inactive_key = 'precondition_frequency' if use_precond_schedule else 'precond_scheduler'
        del group_defaults[inactive_key]

        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_soap)
|
22
127
|
|
23
|
-
PalmForEachSoap = PaLMForeachSOAP
|
24
128
|
|
129
|
+
class PaLMForeachSOAP(ForeachSOAP):
    """ForeachSOAP with a fixed preconditioning frequency and the class-level ``palm`` flag set.

    ``palm`` is presumably read by the base optimizer when the ``palm`` argument is left at
    ``C.use_default`` (not visible here).
    """

    # Fixed-frequency refresh (same as the parent); restated explicitly.
    use_precond_schedule: bool = False
    palm: bool = True
|
132
|
+
|
133
|
+
|
134
|
+
class PrecondScheduleForeachSOAP(ForeachSOAP):
    """ForeachSOAP that refreshes its preconditioner according to ``precond_scheduler``.

    A fixed ``precondition_frequency`` is not used in this variant.
    """

    use_precond_schedule: bool = True
|
136
|
+
|
137
|
+
|
138
|
+
class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
    """ForeachSOAP combining the scheduled preconditioner refresh with the class-level ``palm`` flag."""

    # Use ``precond_scheduler`` instead of a fixed ``precondition_frequency``.
    use_precond_schedule: bool = True
    # Presumably consulted by the base optimizer when ``palm`` is left at ``C.use_default``.
    palm: bool = True
|
141
|
+
|
142
|
+
|
143
|
+
class ForeachPSGDKron(C.BaseOpt):
    """
    Originally from Evan Walters and Omead Pooladzandi, 2024
    Modified under Creative Commons Attribution 4.0 International
    Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py

    PSGD-Kron optimizer.

    The class attributes ``delayed``, ``cached`` and ``exp_avg_input`` supply the values used whenever the
    matching constructor argument is left at ``C.use_default``; subclasses override them to pick a variant.

    Transform chain:
      * optionally ``C.exp_avg`` (when ``exp_avg_input`` is truthy);
      * then ``C.scale_by_delayed_psgd`` or ``C.scale_by_psgd``, bound with ``cached`` and
        ``prob=preconditioner_update_probability``.

    ``palm`` is hard-wired to ``False``. ``update_clipping`` falls back to ``utils.trust_region_clip_``.
    """

    delayed: bool = False
    cached: bool = False
    exp_avg_input: bool = True

    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                 cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
                 # expert parameters
                 precond_init_scale=1.0, precond_lr=0.1):
        # Replace "use the default" sentinels with concrete settings.
        delayed = C.default(delayed, self.delayed)
        cached = C.default(cached, self.cached)
        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
        update_clipping = C.default(update_clipping, utils.trust_region_clip_)

        group_defaults = {
            "lr": lr,
            "beta": beta,
            "weight_decay": weight_decay,
            "max_size_triangular": max_size_triangular,
            "min_ndim_triangular": min_ndim_triangular,
            "memory_save_mode": memory_save_mode,
            "momentum_into_precond_update": momentum_into_precond_update,
            "precond_lr": precond_lr,
            "precond_init_scale": precond_init_scale,
            "step": 0,
            "warmup_steps": warmup_steps,
            "merge_dims": merge_dims,
            "split": split,
            "store_triu_as_line": store_triu_as_line,
            "q_dtype": q_dtype,
            "storage_dtype": storage_dtype,
            "caution": caution,
            "mars_gamma": mars_gamma,
            "mars": mars,
            "stochastic_schedule": stochastic_schedule,
        }

        psgd_fn = C.scale_by_delayed_psgd if delayed else C.scale_by_psgd
        preconditioner = functools.partial(psgd_fn, cached=cached, prob=preconditioner_update_probability)
        # Repeating a 1-tuple by a bool yields ``(C.exp_avg,)`` for True and ``()`` for False, so the
        # momentum transform is only placed in front of the preconditioner when ``exp_avg_input`` is set.
        momentum = (C.exp_avg,) * exp_avg_input
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, False, *momentum,
                         preconditioner)
|
181
|
+
|
182
|
+
|
183
|
+
class ForeachPurePSGD(ForeachPSGDKron):
    """PSGD-Kron without the ``C.exp_avg`` transform in front of the preconditioner (unless overridden)."""

    exp_avg_input: bool = False
|
185
|
+
|
186
|
+
|
187
|
+
class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
    """PSGD-Kron variant defaulting to both ``delayed`` and ``cached``.

    With these defaults the parent constructor selects ``C.scale_by_delayed_psgd`` and binds it with
    ``cached=True``.
    """

    delayed: bool = True
    cached: bool = True
|
190
|
+
|
191
|
+
|
192
|
+
class ForeachCachedPSGDKron(ForeachPSGDKron):
    """PSGD-Kron variant defaulting to ``cached=True``; the parent binds this into ``C.scale_by_psgd``."""

    cached: bool = True
|
194
|
+
|
195
|
+
|
196
|
+
class ForeachDelayedPSGD(ForeachPSGDKron):
    """PSGD-Kron variant defaulting to ``delayed=True``, which selects ``C.scale_by_delayed_psgd``."""

    delayed: bool = True
|
198
|
+
|
199
|
+
|
200
|
+
# Short public aliases for the optimizer classes defined above.
PalmForEachSoap = PaLMForeachSOAP
PaLMSOAP = PaLMForeachSOAP
PaLMSFAdamW = PaLMForeachSFAdamW
SOAP = ForeachSOAP
SFAdamW = ForeachSFAdamW
LaProp = ForeachLaProp
ADOPT = ForeachADOPT
RMSprop = ForeachRMSprop
PrecondScheduleSOAP = PrecondScheduleForeachSOAP
PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
PSGDKron = ForeachPSGDKron
AdamW = ForeachAdamW
PurePSGD = ForeachPurePSGD
DelayedPSGD = ForeachDelayedPSGD
CachedPSGDKron = ForeachCachedPSGDKron
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
Muon = ForeachMuon

# Public API for ``from heavyball import *``.
# Every alias and every optimizer class above is listed exactly once. The previous list repeated
# "RMSprop" and "PrecondSchedulePaLMSOAP" and silently omitted "SOAP", "SFAdamW", "AdamW" and several
# full class names.
__all__ = [
    # short aliases
    "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "SOAP", "SFAdamW", "LaProp", "ADOPT", "RMSprop",
    "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", "PSGDKron", "AdamW", "PurePSGD", "DelayedPSGD",
    "CachedPSGDKron", "CachedDelayedPSGDKron", "Muon",
    # full class names
    "ForeachAdamW", "ForeachRMSprop", "ForeachSFAdamW", "PaLMForeachSFAdamW", "ForeachADOPT", "ForeachMuon",
    "ForeachLaProp", "ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP",
    "PrecondSchedulePaLMForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachCachedDelayedPSGDKron",
    "ForeachCachedPSGDKron", "ForeachDelayedPSGD",
]
|