heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +193 -40
- heavyball/chainable.py +475 -0
- heavyball/utils.py +318 -187
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -93
- heavyball/foreach_solp.py +0 -89
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -101
- heavyball/palm_foreach_solp.py +0 -98
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_foreach_solp.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_palm_foreach_solp.py +0 -103
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.1.dist-info/RECORD +0 -28
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.1.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,56 +1,209 @@
|
|
1
|
-
|
2
|
-
from
|
3
|
-
from .delayed_psgd import ForeachDelayedPSGD
|
4
|
-
from .foreach_adamw import ForeachAdamW
|
5
|
-
from .foreach_adopt import ForeachADOPT
|
6
|
-
from .foreach_laprop import ForeachLaProp
|
7
|
-
from .foreach_sfadamw import ForeachSFAdamW
|
8
|
-
from .foreach_soap import ForeachSOAP
|
9
|
-
from .p_adam import ForeachPaLMPAdam
|
10
|
-
from .palm_foreach_sfadamw import PaLMForeachSFAdamW
|
11
|
-
from .palm_foreach_soap import PaLMForeachSOAP
|
12
|
-
from .precond_schedule_foreach_soap import PrecondScheduleForeachSOAP
|
13
|
-
from .precond_schedule_palm_foreach_soap import PrecondSchedulePaLMForeachSOAP
|
14
|
-
from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
|
15
|
-
from .psgd_kron import ForeachPSGDKron
|
16
|
-
from .pure_psgd import ForeachPurePSGD
|
17
|
-
from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
|
18
|
-
from .foreach_solp import ForeachSOLP
|
19
|
-
from .palm_foreach_solp import PaLMForeachSOLP
|
20
|
-
from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
|
21
|
-
from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
|
1
|
+
import functools
|
2
|
+
from typing import Optional
|
22
3
|
|
23
|
-
|
4
|
+
from . import chainable as C
|
5
|
+
from . import utils
|
6
|
+
|
7
|
+
|
8
|
+
class ForeachAdamW(C.BaseOpt):
    """
    Optimizer that hands ``C.update_by_adam`` to ``C.BaseOpt`` as its only update function.

    The constructor only gathers hyperparameters into the per-group defaults dict; ``foreach``, the clipping
    functions and ``palm`` are forwarded to ``C.BaseOpt`` unchanged.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # ``k``, ``train_mode``, ``weight_sum`` and ``lr_max`` are fixed initial values, not user arguments.
        group_defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'k': 0,
            'warmup_steps': warmup_steps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'weight_decay': weight_decay,
            'storage_dtype': storage_dtype,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'beta2_scale': beta2_scale,
        }
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.update_by_adam)
|
17
|
+
|
18
|
+
|
19
|
+
class ForeachRMSprop(C.BaseOpt):
    """
    Debiased RMSprop (not torch.optim.RMSprop)

    Built from ``C.scale_by_exp_avg_sq`` as the single update function passed to ``C.BaseOpt``.

    Note: ``r`` and ``weight_lr_power`` are accepted by the constructor but are not placed into the
    defaults dict or forwarded anywhere by this class.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        group_defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'warmup_steps': warmup_steps,
            'weight_decay': weight_decay,
            'foreach': foreach,
            'storage_dtype': storage_dtype,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'beta2_scale': beta2_scale,
        }
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.scale_by_exp_avg_sq)
|
32
|
+
|
33
|
+
|
34
|
+
class ForeachSFAdamW(C.ScheduleFree):
    """
    ``C.ScheduleFree`` optimizer chaining ``C.scale_by_exp_avg_sq`` with ``C.update_by_schedule_free``.

    The constructor gathers the hyperparameters (plus a few fixed initial values) into the per-group
    defaults dict and delegates everything else to the base class.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
                 weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # ``k``, ``train_mode``, ``weight_sum`` and ``lr_max`` are fixed initial values, not user arguments.
        group_defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'r': r,
            'k': 0,
            'warmup_steps': warmup_steps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'weight_lr_power': weight_lr_power,
            'weight_decay': weight_decay,
            'foreach': foreach,
            'storage_dtype': storage_dtype,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'beta2_scale': beta2_scale,
        }
        update_chain = (C.scale_by_exp_avg_sq, C.update_by_schedule_free)
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm, *update_chain)
|
45
|
+
|
46
|
+
|
47
|
+
class PaLMForeachSFAdamW(ForeachSFAdamW):
    """``ForeachSFAdamW`` with the class-level ``palm`` flag set to ``True``; nothing else is overridden."""

    palm: bool = True
|
49
|
+
|
50
|
+
|
51
|
+
class ForeachADOPT(C.BaseOpt):
    """
    Optimizer that hands ``C.update_by_adopt`` to ``C.BaseOpt`` as its only update function.

    The constructor only gathers hyperparameters into the per-group defaults dict; ``foreach``, the clipping
    functions and ``palm`` are forwarded to ``C.BaseOpt`` unchanged.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # ``k``, ``train_mode``, ``weight_sum`` and ``lr_max`` are fixed initial values, not user arguments.
        group_defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'k': 0,
            'warmup_steps': warmup_steps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'weight_decay': weight_decay,
            'storage_dtype': storage_dtype,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'beta2_scale': beta2_scale,
        }
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.update_by_adopt)
|
60
|
+
|
61
|
+
|
62
|
+
class ForeachLaProp(C.BaseOpt):
    """
    Optimizer that hands ``C.update_by_laprop`` to ``C.BaseOpt`` as its only update function.

    The constructor only gathers hyperparameters into the per-group defaults dict; ``foreach``, the clipping
    functions and ``palm`` are forwarded to ``C.BaseOpt`` unchanged.
    """

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        # ``k``, ``train_mode``, ``weight_sum`` and ``lr_max`` are fixed initial values, not user arguments.
        group_defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'k': 0,
            'warmup_steps': warmup_steps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'weight_decay': weight_decay,
            'storage_dtype': storage_dtype,
            'mars': mars,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'beta2_scale': beta2_scale,
        }
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.update_by_laprop)
|
71
|
+
|
72
|
+
|
73
|
+
class ForeachSOAP(C.BaseOpt):
    """
    ForeachSOAP

    Hands ``C.scale_by_soap`` to ``C.BaseOpt`` as its only update function. Exactly one of
    ``precondition_frequency`` / ``precond_scheduler`` ends up in the per-group defaults, selected by
    ``use_precond_schedule`` (constructor argument, resolved through ``C.default`` against the class attribute
    of the same name). ``data_format`` is accepted but not used by this constructor.

    Sources:
        Baseline SOAP:
            SOAP: Improving and Stabilizing Shampoo using Adam
            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
            https://arxiv.org/abs/2409.11321
            https://github.com/nikhilvyas/SOAP

        ScheduleFree:
            The Road Less Scheduled
            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
            https://arxiv.org/abs/2405.15682
            https://github.com/facebookresearch/schedule_free
    """
    use_precond_schedule: bool = False

    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,
                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                 beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

        group_defaults = dict(
            lr=lr, betas=betas, shampoo_beta=shampoo_beta, eps=eps, weight_decay=weight_decay,
            precondition_frequency=precondition_frequency, max_precond_dim=max_precond_dim,
            merge_dims=merge_dims, precondition_1d=precondition_1d, normalize_grads=normalize_grads,
            correct_bias=correct_bias, warmup_steps=warmup_steps, split=split, mars=mars,
            caution=caution, mars_gamma=mars_gamma, palm=palm, precond_scheduler=precond_scheduler,
            beta2_scale=beta2_scale,
        )

        # Only the setting matching the selected preconditioning mode is kept; the other key is dropped.
        unused_key = 'precondition_frequency' if use_precond_schedule else 'precond_scheduler'
        del group_defaults[unused_key]

        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.scale_by_soap)
|
114
|
+
|
115
|
+
|
116
|
+
class PaLMForeachSOAP(ForeachSOAP):
    """``ForeachSOAP`` with class-level ``palm`` enabled and ``use_precond_schedule`` left off."""

    use_precond_schedule: bool = False
    palm: bool = True
|
24
119
|
|
120
|
+
|
121
|
+
class PrecondScheduleForeachSOAP(ForeachSOAP):
    """``ForeachSOAP`` whose class-level ``use_precond_schedule`` default is ``True``."""

    use_precond_schedule: bool = True
|
123
|
+
|
124
|
+
|
125
|
+
class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
    """``ForeachSOAP`` with both class-level flags, ``use_precond_schedule`` and ``palm``, set to ``True``."""

    use_precond_schedule: bool = True
    palm: bool = True
|
128
|
+
|
129
|
+
|
130
|
+
class ForeachPSGDKron(C.BaseOpt):
    """
    Originally from Evan Walters and Omead Pooladzandi, 2024
    Modified under Creative Commons Attribution 4.0 International
    Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py

    The class attributes ``delayed``, ``cached`` and ``exp_avg_input`` act as defaults for the constructor
    arguments of the same name (resolved through ``C.default``), which lets subclasses select a variant by
    overriding the attribute alone.
    """

    delayed: bool = False
    cached: bool = False
    exp_avg_input: bool = True

    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
                 split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                 caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                 cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
                 # expert parameters
                 precond_init_scale=1.0, precond_lr=0.1):
        delayed = C.default(delayed, self.delayed)
        cached = C.default(cached, self.cached)
        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
        # Unlike the other optimizers in this module, the update-clipping fallback is a trust-region clip.
        update_clipping = C.default(update_clipping, utils.trust_region_clip_)

        group_defaults = {
            'lr': lr,
            'beta': beta,
            'weight_decay': weight_decay,
            'max_size_triangular': max_size_triangular,
            'min_ndim_triangular': min_ndim_triangular,
            'memory_save_mode': memory_save_mode,
            'momentum_into_precond_update': momentum_into_precond_update,
            'precond_lr': precond_lr,
            'precond_init_scale': precond_init_scale,
            'step': 0,
            'warmup_steps': warmup_steps,
            'merge_dims': merge_dims,
            'split': split,
            'store_triu_as_line': store_triu_as_line,
            'q_dtype': q_dtype,
            'storage_dtype': storage_dtype,
            'caution': caution,
            'mars_gamma': mars_gamma,
            'mars': mars,
            'stochastic_schedule': stochastic_schedule,
        }

        # Tuple-times-flag: one ``C.exp_avg`` stage when the flag is True, none when it is False.
        momentum_stage = (C.exp_avg,) * exp_avg_input
        precond_stage = functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd,
                                          cached=cached, prob=preconditioner_update_probability)

        # ``palm`` is hard-wired to False for this optimizer family.
        super().__init__(params, group_defaults, foreach, gradient_clipping, update_clipping, False,
                         *momentum_stage, precond_stage)
|
168
|
+
|
169
|
+
|
170
|
+
class ForeachPurePSGD(ForeachPSGDKron):
    """``ForeachPSGDKron`` with ``exp_avg_input`` defaulting to ``False``, so no ``C.exp_avg`` stage is chained."""

    exp_avg_input: bool = False
|
172
|
+
|
173
|
+
|
174
|
+
class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
    """``ForeachPSGDKron`` with both ``delayed`` and ``cached`` defaulting to ``True``."""

    delayed: bool = True
    cached: bool = True
|
177
|
+
|
178
|
+
|
179
|
+
class ForeachCachedPSGDKron(ForeachPSGDKron):
    """``ForeachPSGDKron`` with ``cached`` defaulting to ``True``."""

    cached: bool = True
|
181
|
+
|
182
|
+
|
183
|
+
class ForeachDelayedPSGD(ForeachPSGDKron):
    """``ForeachPSGDKron`` with ``delayed`` defaulting to ``True`` (selects ``C.scale_by_delayed_psgd``)."""

    delayed: bool = True
|
185
|
+
|
186
|
+
|
187
|
+
# Short public aliases for the Foreach* optimizer classes defined above.

# SOAP / schedule-free PaLM variants
PalmForEachSoap = PaLMForeachSOAP
PaLMSOAP = PaLMForeachSOAP
PaLMSFAdamW = PaLMForeachSFAdamW

# Base optimizers
SOAP = ForeachSOAP
SFAdamW = ForeachSFAdamW
LaProp = ForeachLaProp
ADOPT = ForeachADOPT
RMSprop = ForeachRMSprop

# Preconditioner-schedule SOAP variants
PrecondScheduleSOAP = PrecondScheduleForeachSOAP
PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP

# PSGD family (AdamW alias kept in its original position)
PSGDKron = ForeachPSGDKron
AdamW = ForeachAdamW
PurePSGD = ForeachPurePSGD
DelayedPSGD = ForeachDelayedPSGD
CachedPSGDKron = ForeachCachedPSGDKron
CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
|
50
|
-
'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
|
51
|
-
'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
|
52
|
-
#
|
53
|
-
'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
|
54
|
-
'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
|
55
|
-
'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PrecondScheduleSOLP',
|
56
|
-
'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
|
203
|
+
|
204
|
+
# Public API for ``from heavyball import *``.
# Each name is listed exactly once. Aliases and classes that are defined in this module but were missing
# from the list (SOAP, SFAdamW, AdamW and the PaLM / PrecondSchedule Foreach classes) are included so the
# star-import exposes every public optimizer.
__all__ = [
    # short aliases
    "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron", "CachedDelayedPSGDKron",
    "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT", "PrecondScheduleSOAP", "RMSprop",
    "SOAP", "SFAdamW", "AdamW",
    # optimizer classes
    "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
    "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
    "ForeachRMSprop", "PaLMForeachSFAdamW", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP",
    "PrecondSchedulePaLMForeachSOAP",
]
|