heavyball 1.5.2__tar.gz → 1.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.5.2 → heavyball-1.6.0}/PKG-INFO +4 -6
- {heavyball-1.5.2 → heavyball-1.6.0}/README.md +1 -3
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/__init__.py +73 -17
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/chainable.py +76 -14
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/utils.py +322 -175
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/PKG-INFO +4 -6
- {heavyball-1.5.2 → heavyball-1.6.0}/setup.py +3 -3
- {heavyball-1.5.2 → heavyball-1.6.0}/LICENSE +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/setup.cfg +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_params.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_q.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_storage.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_caution.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_channels_last.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_closure.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_ema.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_foreach.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_hook.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_mars.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_memory.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_merge.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_no_grad.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_psgd.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_soap.py +0 -0
- {heavyball-1.5.2 → heavyball-1.6.0}/test/test_stochastic_updates.py +0 -0
@@ -1,9 +1,9 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.2
+Version: 1.6.0
 Summary: Efficient optimizers
-Home-page: https://github.com/
-Author:
+Home-page: https://github.com/HomebrewML/HeavyBall
+Author: HeavyBall Authors
 Author-email: github.heavyball@nestler.sh
 License: BSD
 Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
+                 correct_bias: bool = True, warmup_steps: int = 1,
                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
 * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
 * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
 * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
-* **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
 * **`correct_bias`**: Enables/disables bias correction for the running averages.
 * **`warmup_steps`**: Number of steps for linear learning rate warmup.
 * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
 * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
   `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
   configurations.
-
@@ -276,7 +276,7 @@ class ForeachSOAP(C.BaseOpt):
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
+                 correct_bias: bool = True, warmup_steps: int = 1,
                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -300,7 +300,6 @@ the second-order statistics of the gradients to accelerate convergence.
 * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
 * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
 * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
-* **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
 * **`correct_bias`**: Enables/disables bias correction for the running averages.
 * **`warmup_steps`**: Number of steps for linear learning rate warmup.
 * **`split`**: Whether to split large dimensions when forming the preconditioner.
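For orientation, a minimal sketch of how the parameters documented above are passed to `ForeachSOAP`; the toy model, loop, and values are illustrative only, not recommended settings.

```python
import torch
import heavyball

model = torch.nn.Linear(64, 10)  # toy model; any nn.Module works the same way

# ForeachSOAP with a few of the documented knobs set explicitly.
opt = heavyball.ForeachSOAP(
    model.parameters(),
    lr=3e-3,
    precondition_frequency=2,  # how often the preconditioner is refreshed
    max_precond_dim=2048,      # cap on the dimension size that is fully preconditioned
    merge_dims=True,
    correct_bias=True,
    warmup_steps=100,          # linear learning-rate warmup
)

for _ in range(10):
    loss = model(torch.randn(32, 64)).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```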
@@ -907,4 +906,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
 * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
   `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
   configurations.
-
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
     Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
     https://arxiv.org/abs/2409.11321
     https://github.com/nikhilvyas/SOAP
-
-    ScheduleFree:
-    The Road Less Scheduled
-    Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-    https://arxiv.org/abs/2405.15682
-    https://github.com/facebookresearch/schedule_free
     """
     use_precond_schedule: bool = False

     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
-
-
-
-
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

         defaults = locals()
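A small sketch of the reworked constructor, assuming the keyword names visible in the hunk above (`storage_dtype`, `stochastic_schedule`, `use_precond_schedule`); the values are illustrative, not recommendations.

```python
import torch
import heavyball

model = torch.nn.Linear(16, 16)

opt = heavyball.ForeachSOAP(
    model.parameters(),
    lr=3e-3,
    shampoo_beta=0.95,
    use_precond_schedule=True,     # precondition on a schedule instead of a fixed frequency
    precond_scheduler=(1 / 3, 9),  # consumed only when the schedule is enabled
    storage_dtype='float32',       # dtype used for optimizer state buffers
    stochastic_schedule=False,
)
```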
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
                          C.scale_by_soap)


+class ForeachSignLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+class ForeachSOLP(C.BaseOpt):
+    """
+    ForeachSOLP
+
+    Sources:
+    Baseline SOAP:
+    SOAP: Improving and Stabilizing Shampoo using Adam
+    Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+    https://arxiv.org/abs/2409.11321
+    https://github.com/nikhilvyas/SOAP
+    """
+    use_precond_schedule: bool = False
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
+        if use_precond_schedule:
+            del defaults['precondition_frequency']
+            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+        else:
+            del defaults['precond_scheduler']
+            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
+                         functools.partial(C.scale_by_soap, inner='laprop'))
+
+
 class PaLMForeachSOAP(ForeachSOAP):
     use_precond_schedule: bool = False
     palm: bool = True
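To make the two additions concrete: `ForeachSignLaProp` chains LaProp scaling with a sign step (`C.scale_by_laprop` followed by `C.sign`), and `ForeachSOLP` is SOAP with LaProp as the inner optimizer (`inner='laprop'`). A hedged construction sketch with illustrative values:

```python
import torch
import heavyball

model = torch.nn.Linear(32, 32)

# LaProp scaling followed by a sign step.
sign_laprop = heavyball.ForeachSignLaProp(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99))

# SOAP whose inner optimizer is LaProp rather than Adam.
solp = heavyball.ForeachSOLP(model.parameters(), lr=3e-3, precondition_frequency=2)
```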
@@ -163,6 +205,18 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)


+class LaPropOrtho(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+                         C.orthogonalize_grad_to_param)
+
+
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
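The only difference from the existing `OrthoLaProp` is the order of the two transforms: `OrthoLaProp` orthogonalizes the gradient against the parameter and then applies LaProp, while the new `LaPropOrtho` applies LaProp first and orthogonalizes the result. A brief sketch (values illustrative):

```python
import torch
import heavyball

model = torch.nn.Linear(32, 32)

ortho_then_laprop = heavyball.OrthoLaProp(model.parameters(), lr=2.5e-3)   # orthogonalize -> LaProp
laprop_then_ortho = heavyball.LaPropOrtho(model.parameters(), lr=2.5e-3)   # LaProp -> orthogonalize (new)
```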
@@ -178,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
                  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool =
+                 stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
@@ -238,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon
+SignLaProp = ForeachSignLaProp

 __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
            "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
-
-           "
-           "ForeachRMSprop", "ForeachMuon",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp'  #
+           "ForeachAdamW", "ForeachSFAdamW",
+           "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+           "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+           'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
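`SignLaProp` is registered as a plain alias of `ForeachSignLaProp`, mirroring the existing `Muon = ForeachMuon` pattern, so both names construct the same optimizer:

```python
import heavyball

assert heavyball.SignLaProp is heavyball.ForeachSignLaProp
```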
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List

 import torch

@@ -127,7 +127,7 @@ def zero_guard(*names):


 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names
+    return functools.partial(CopyGuard, index=index, names=names)


 def general_guard(*names, init_fn, skip_first: bool = True):
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
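Both new transforms decay the parameters toward an exponential moving average of the iterate rather than toward zero (with `l1_weight_decay_to_ema` presumably using an L1-style pull). `utils.weight_decay_to_ema_` itself is not part of this diff, so the snippet below is only a stand-alone illustration of the idea, with hypothetical names and a plain-PyTorch EMA:

```python
import torch

def decay_to_ema_sketch(param: torch.Tensor, ema: torch.Tensor, beta: float, strength: float) -> None:
    """Illustration only: weight decay toward an EMA of the parameter instead of toward zero."""
    ema.mul_(beta).add_(param, alpha=1 - beta)  # track an EMA of the iterate
    param.add_(ema - param, alpha=strength)     # pull the parameter a small step toward that EMA

p = torch.randn(4, 4)
ema = p.clone()
decay_to_ema_sketch(p, ema, beta=0.99, strength=1e-2)
```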
@@ -206,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate

     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate

     utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
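The two `SkipUpdate` branches implement ADOPT's warm-up: step 1 only seeds the second-moment estimate from the first gradient, step 2 seeds the first moment with the gradient normalized by that estimate, and real updates start at step 3. A stand-alone sketch of that control flow (not heavyball's fused implementation):

```python
import torch

def adopt_warmup_sketch(grad: torch.Tensor, step: int, state: dict, eps: float = 1e-8):
    """Illustration only: ADOPT-style warm-up that produces no update for the first two steps."""
    if step == 1:
        state['exp_avg_sq'] = grad.square()       # seed the second moment
        return None
    if step == 2:
        denom = state['exp_avg_sq'].sqrt().clamp(min=eps)
        state['exp_avg'] = grad / denom           # seed the first moment
        return None
    return state['exp_avg']                       # from step 3 on, a real update is computed

state = {}
for step in range(1, 4):
    update = adopt_warmup_sketch(torch.randn(3), step, state)
```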
@@ -225,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate

     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate

     return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)


-def _init_soap(state, group, update, grad, param):
-    utils.init_preconditioner(grad, state,
+def _init_soap(state, group, update, grad, param, inner: str = ''):
+    utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])


 def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -295,6 +313,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
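The comment on `nesterov_ema` ties it to Grokfast, i.e. the update is boosted by a low-pass-filtered (EMA) view of the gradient history, while `mup_approx` rescales multi-dimensional updates by statistics captured at initialization. `utils.nesterov_ema` and `utils.stochastic_multiply_` are not shown in this diff, so the following is only an illustrative stand-in for the Grokfast-style idea, with hypothetical names and constants:

```python
import torch

def grokfast_like_sketch(grad: torch.Tensor, ema: torch.Tensor, beta: float = 0.98, amp: float = 2.0) -> torch.Tensor:
    """Illustration only: amplify the slow (EMA) component of the gradient history."""
    ema.mul_(beta).add_(grad, alpha=1 - beta)  # low-pass filter over past gradients
    return grad + amp * ema                    # boosted update direction

ema = torch.zeros(8)
for _ in range(5):
    boosted = grokfast_like_sketch(torch.randn(8), ema)
```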
@@ -308,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
 def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
-    update = utils.promote(update)
+    update = utils.promote(update)  # Promote to highest precision if needed

     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step']
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+                 group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

-    for u, q, gg,
-        utils.update_preconditioner(u, q, gg,
+    for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+        utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
                                     utils.beta_debias(group['shampoo_beta'], group['step']),
                                     group['is_preconditioning'])
     return precond
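`scale_by_soap` keeps the usual SOAP structure: project the update into the preconditioner's eigenbasis, run the inner optimizer there (Adam by default, LaProp when `inner='laprop'`), project back, and then refresh the preconditioner statistics. The helper below is only a schematic of that rotate-step-rotate pattern, with a placeholder projection rather than heavyball's `utils.project`:

```python
import torch

def soap_step_sketch(update: torch.Tensor, q: torch.Tensor, inner_step) -> torch.Tensor:
    """Illustration only: rotate into the eigenbasis q, apply the inner step, rotate back."""
    projected = update @ q                   # 1. into the eigenbasis
    preconditioned = inner_step(projected)   # 2. inner optimizer step (Adam / LaProp stand-in)
    return preconditioned @ q.T              # 3. back to the original basis

q = torch.linalg.qr(torch.randn(6, 6)).Q    # stand-in orthogonal basis
out = soap_step_sketch(torch.randn(4, 6), q, inner_step=torch.sign)
```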
@@ -414,6 +452,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate


+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
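`sign` forwards to `utils.sign_(update, graft)`, which is not included in this diff; grafting is assumed here to mean rescaling the sign direction to the magnitude of the original update, as sketched below with hypothetical helper names:

```python
import torch

def sign_with_graft_sketch(update: torch.Tensor, graft: bool = True) -> torch.Tensor:
    """Illustration only: sign of the update, optionally grafted onto the original norm."""
    signed = update.sign()
    if graft:
        signed = signed * update.norm() / signed.norm().clamp(min=1e-12)
    return signed

u = torch.randn(10)
grafted = sign_with_graft_sketch(u)  # same sign pattern as u, roughly the same norm as u
```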
@@ -439,8 +482,7 @@ def apply_to_idx(fn, idx):
     return _fn


-def
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -450,10 +492,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)


+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False

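`chain` now clones the gradient and delegates to `_inner_chain`, and `create_branch` reuses `_inner_chain` to run several transform chains on separate clones of the same update before merging the results. A hedged sketch of assembling such a branch from this module's transforms; the averaging `merge_fn` is purely illustrative, and actually driving the branch requires the surrounding `ChainOpt` machinery:

```python
import torch
from heavyball import chainable as C

def average_merge(outputs):
    # One list of update tensors per branch; merge by element-wise averaging.
    return [torch.stack(tensors).mean(0) for tensors in zip(*outputs)]

# Two branches over clones of the same update: a sign branch and an EMA branch.
branch_fn = C.create_branch([[C.sign], [C.exp_avg]], average_merge)
```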