heavyball 1.5.2.tar.gz → 1.6.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.5.2 → heavyball-1.6.0}/PKG-INFO +4 -6
  2. {heavyball-1.5.2 → heavyball-1.6.0}/README.md +1 -3
  3. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/__init__.py +73 -17
  4. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/chainable.py +76 -14
  5. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball/utils.py +322 -175
  6. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/PKG-INFO +4 -6
  7. {heavyball-1.5.2 → heavyball-1.6.0}/setup.py +3 -3
  8. {heavyball-1.5.2 → heavyball-1.6.0}/LICENSE +0 -0
  9. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.5.2 → heavyball-1.6.0}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.5.2 → heavyball-1.6.0}/setup.cfg +0 -0
  14. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_caution.py +0 -0
  18. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_channels_last.py +0 -0
  19. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_closure.py +0 -0
  20. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_ema.py +0 -0
  21. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_foreach.py +0 -0
  22. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_hook.py +0 -0
  23. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_mars.py +0 -0
  24. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_memory.py +0 -0
  25. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_merge.py +0 -0
  26. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_no_grad.py +0 -0
  27. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_psgd.py +0 -0
  28. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_soap.py +0 -0
  29. {heavyball-1.5.2 → heavyball-1.6.0}/test/test_stochastic_updates.py +0 -0
@@ -1,9 +1,9 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.5.2
+ Version: 1.6.0
  Summary: Efficient optimizers
- Home-page: https://github.com/clashluke/heavyball
- Author: Lucas Nestler
+ Home-page: https://github.com/HomebrewML/HeavyBall
+ Author: HeavyBall Authors
  Author-email: github.heavyball@nestler.sh
  License: BSD
  Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+ correct_bias: bool = True, warmup_steps: int = 1,
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
  configurations.
-
@@ -276,7 +276,7 @@ class ForeachSOAP(C.BaseOpt):
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+ correct_bias: bool = True, warmup_steps: int = 1,
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -300,7 +300,6 @@ the second-order statistics of the gradients to accelerate convergence.
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -907,4 +906,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
  configurations.
-
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
  Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
  https://arxiv.org/abs/2409.11321
  https://github.com/nikhilvyas/SOAP
-
- ScheduleFree:
- The Road Less Scheduled
- Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
- https://arxiv.org/abs/2405.15682
- https://github.com/facebookresearch/schedule_free
  """
  use_precond_schedule: bool = False

  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
- split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
- mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
- beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

  defaults = locals()
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
  C.scale_by_soap)


+ class ForeachSignLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+ class ForeachSOLP(C.BaseOpt):
+ """
+ ForeachSOLP
+
+ Sources:
+ Baseline SOAP:
+ SOAP: Improving and Stabilizing Shampoo using Adam
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+ https://arxiv.org/abs/2409.11321
+ https://github.com/nikhilvyas/SOAP
+ """
+ use_precond_schedule: bool = False
+
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
+ if use_precond_schedule:
+ del defaults['precondition_frequency']
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+ else:
+ del defaults['precond_scheduler']
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
+ functools.partial(C.scale_by_soap, inner='laprop'))
+
+
  class PaLMForeachSOAP(ForeachSOAP):
  use_precond_schedule: bool = False
  palm: bool = True
@@ -163,6 +205,18 @@ class OrthoLaProp(C.BaseOpt):
  C.orthogonalize_grad_to_param, C.scale_by_laprop)


+ class LaPropOrtho(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+ C.orthogonalize_grad_to_param)
+
+
  class ForeachPSGDKron(C.BaseOpt):
  """
  Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -178,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+ stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  defaults = locals()
@@ -238,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  Muon = ForeachMuon
+ SignLaProp = ForeachSignLaProp

  __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
- "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
- "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
- "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
+ "ForeachAdamW", "ForeachSFAdamW",
+ "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+ "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+ 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
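As orientation for the new public classes added to heavyball/__init__.py above (ForeachSignLaProp, ForeachSOLP, LaPropOrtho, and the SignLaProp alias), here is a minimal usage sketch. It assumes only the constructor signatures shown in this diff plus the usual PyTorch optimizer step/zero_grad interface; the model and training loop are placeholders, not part of the package.

    import torch
    import heavyball

    model = torch.nn.Linear(16, 4)  # placeholder module, not from the package
    opt = heavyball.ForeachSignLaProp(model.parameters(), lr=0.0025, betas=(0.9, 0.99))
    # ForeachSOLP mirrors ForeachSOAP's arguments but runs LaProp as the inner optimizer:
    # opt = heavyball.ForeachSOLP(model.parameters(), lr=3e-3, precondition_frequency=2)

    for _ in range(3):  # toy loop, only to show the optimizer interface
        opt.zero_grad()
        loss = model(torch.randn(8, 16)).square().mean()
        loss.backward()
        opt.step()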
@@ -1,6 +1,6 @@
  import functools
  import random
- from typing import Optional, Union, Literal
+ from typing import Optional, Union, Literal, List

  import torch

@@ -127,7 +127,7 @@ def zero_guard(*names):


  def copy_guard(index, *names):
- return functools.partial(CopyGuard, index=index, names=names, )
+ return functools.partial(CopyGuard, index=index, names=names)


  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
  return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


+ @zero_guard('exp_avg')
+ @no_state
+ def weight_decay_to_ema(group, update, grad, param, exp_avg):
+ utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+ group['weight_decay_to_ema'] * group['lr'])
+ return update
+
+
+ @zero_guard('exp_avg')
+ @no_state
+ def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+ utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+ group['weight_decay_to_ema'] * group['lr'])
+ return update
+
+
  @zero_guard("exp_avg_sq")
  @no_state
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
@@ -206,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
  @no_state
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -225,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  @no_state
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)


- def _init_soap(state, group, update, grad, param):
- utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
+ def _init_soap(state, group, update, grad, param, inner: str = ''):
+ utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])


  def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -295,6 +313,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
  return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


+ @zero_guard('momentum')
+ @no_state
+ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grokfast
+ return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+ def _store_std(state, group, update, grad, param):
+ state['init_std'] = torch.std(grad, dim=0)
+
+
+ @general_guard("init_std", init_fn=_store_std)
+ @no_state
+ def mup_approx(group, updates, grads, params, init_std):
+ _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+ _updates, _init_std = zip(*_updates)
+ utils.stochastic_multiply_(_updates, _init_std)
+ return updates
+
+
  @zero_guard("momentum")
  @no_state
  def heavyball_momentum(group, updates, grads, params, momentum):
@@ -308,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
- update = utils.promote(update)
+ update = utils.promote(update) # Promote to highest precision if needed

  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
  fn = _optim_fns[inner]
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+ group['eps'])
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

- for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
- utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
+ for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+ utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
  utils.beta_debias(group['shampoo_beta'], group['step']),
  group['is_preconditioning'])
  return precond
@@ -414,6 +452,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
  raise SkipUpdate


+ @no_state
+ def sign(group, update, grad, param, graft: bool = True):
+ return utils.sign_(update, graft)
+
+
  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
@@ -439,8 +482,7 @@ def apply_to_idx(fn, idx):
  return _fn


- def chain(state: Union[callable, dict], group, grad, param, *fns):
- update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+ def _inner_chain(state, group, update, grad, param, *fns):
  skip_update = False
  for fn in fns:
  try:
@@ -450,10 +492,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
  continue
  if update is None:
  break
+ return update, skip_update
+
+
+ def chain(state: Union[callable, dict], group, grad, param, *fns):
+ update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+ update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
  if not skip_update and update is not None:
  utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)


+ def create_branch(branches: List[List[callable]], merge_fn: callable):
+ def _branch(state, group, update, grad, param):
+ outputs = []
+ for branch in branches:
+ branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+ branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+ if skip_update:
+ raise ValueError("Branches should not skip updates")
+ outputs.append(branch_update)
+ return merge_fn(outputs)
+
+ return _branch
+
+
  class ChainOpt(utils.StatefulOptimizer):
  promote: bool = False
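The create_branch helper added above runs several independent chains of transform functions on clones of the running update via _inner_chain and then combines the branch outputs with a caller-supplied merge_fn. Below is a minimal sketch of how it might be wired up; it assumes only the names introduced in this diff (create_branch, scale_by_laprop, nesterov_ema), while the averaging merge function and the particular pair of branches are illustrative, not from the package.

    from heavyball import chainable as C

    def mean_merge(outputs):
        # outputs holds one list of per-parameter update tensors per branch;
        # average them element-wise across branches (illustrative choice)
        return [sum(tensors) / len(tensors) for tensors in zip(*outputs)]

    # each inner list is one branch's chain of chainable transforms
    branch_fn = C.create_branch([[C.scale_by_laprop], [C.nesterov_ema]], mean_merge)
    # branch_fn(state, group, update, grad, param) then behaves like any other
    # chainable transform, returning the merged update for the rest of the chain.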