heavyball 1.5.1__tar.gz → 1.5.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.5.1 → heavyball-1.5.2}/PKG-INFO +1 -1
  2. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball/__init__.py +1 -36
  3. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball/chainable.py +10 -5
  4. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball/utils.py +11 -6
  5. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball.egg-info/PKG-INFO +1 -1
  6. {heavyball-1.5.1 → heavyball-1.5.2}/setup.py +1 -1
  7. {heavyball-1.5.1 → heavyball-1.5.2}/LICENSE +0 -0
  8. {heavyball-1.5.1 → heavyball-1.5.2}/README.md +0 -0
  9. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.5.1 → heavyball-1.5.2}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.5.1 → heavyball-1.5.2}/setup.cfg +0 -0
  14. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_caution.py +0 -0
  18. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_channels_last.py +0 -0
  19. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_closure.py +0 -0
  20. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_ema.py +0 -0
  21. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_foreach.py +0 -0
  22. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_hook.py +0 -0
  23. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_mars.py +0 -0
  24. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_memory.py +0 -0
  25. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_merge.py +0 -0
  26. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_no_grad.py +0 -0
  27. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_psgd.py +0 -0
  28. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_soap.py +0 -0
  29. {heavyball-1.5.1 → heavyball-1.5.2}/test/test_stochastic_updates.py +0 -0
--- heavyball-1.5.1/PKG-INFO
+++ heavyball-1.5.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.1
+Version: 1.5.2
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
--- heavyball-1.5.1/heavyball/__init__.py
+++ heavyball-1.5.2/heavyball/__init__.py
@@ -163,41 +163,6 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)


-class ForeachAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
-
-
-class OrthoAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_adam)
-
-
-class AdamWOrtho(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
-                         C.orthogonalize_grad_to_param)
-
-
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -216,7 +181,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
--- heavyball-1.5.1/heavyball/chainable.py
+++ heavyball-1.5.2/heavyball/chainable.py
@@ -364,10 +364,12 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache


-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out


 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
@@ -387,7 +389,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)


 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +397,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -467,6 +469,8 @@ class ChainOpt(utils.StatefulOptimizer):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']

+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))

         if not vals:
@@ -492,6 +496,7 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             chain(self.state_, group, g, p, *self.fns)

+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
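The chainable.py hunks above thread the per-group `caution` flag into the PSGD preconditioning step: `_cached_psgd_precond_grad` applies the mask and then clears `group['caution']` so later stages cannot apply it twice, while `ChainOpt` saves the flag before the step and restores it afterwards so the user-facing setting is preserved across steps. Below is a minimal, self-contained sketch of that save/consume/restore pattern; `apply_caution`, `precondition`, and `optimizer_step` are hypothetical stand-ins, with `apply_caution` assumed to zero update entries whose sign disagrees with the gradient (an assumption about what heavyball's `_compilable_cautioning` does).

    import torch

    def apply_caution(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for _compilable_cautioning: keep only the entries
        # of the update whose sign agrees with the raw gradient.
        return update * ((grad * update) > 0)

    def precondition(group: dict, update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        if group['caution']:
            update = apply_caution(grad, update)
        group['caution'] = False  # consumed here; later stages must not caution again
        return update  # a real implementation would also apply the Q preconditioners

    def optimizer_step(group: dict, update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        caution = group['caution']   # save the user-facing setting
        out = precondition(group, update, grad)
        group['caution'] = caution   # restore it for the next step
        return out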
 
--- heavyball-1.5.1/heavyball/utils.py
+++ heavyball-1.5.2/heavyball/utils.py
@@ -1300,7 +1300,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random


 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1312,8 +1315,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =

 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)


 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1322,7 +1325,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca


 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1332,8 +1337,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):

 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)


 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
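With the utils.py changes above, the non-fused helpers `precond_grad_cached_` and `psgd_precond_grad` can apply cautioning to the incoming update `ea` before the einsum preconditioning, and the fused wrappers accordingly pass `caution=False` on to `update_param_`. For reference, here is a self-contained sketch of the sign-agreement masking that the `caution`/`grad` keywords feed into; this is an assumption about what `_compilable_cautioning` does (any rescaling it may additionally perform is omitted), not a copy of heavyball's implementation.

    import torch

    def cautious_mask(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        # Keep only the components of the update whose sign agrees with the raw gradient.
        return update * ((grad * update) > 0)

    grad = torch.tensor([0.5, -1.0, 2.0])
    update = torch.tensor([0.3, 0.4, -1.5])  # the last two entries disagree with grad
    print(cautious_mask(grad, update))       # only the first entry survives the mask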
--- heavyball-1.5.1/heavyball.egg-info/PKG-INFO
+++ heavyball-1.5.2/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.1
+Version: 1.5.2
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
--- heavyball-1.5.1/setup.py
+++ heavyball-1.5.2/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='1.5.1',
+    version='1.5.2',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),