heavyball 1.2.1__tar.gz → 1.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.2.1 → heavyball-1.2.2}/PKG-INFO +1 -1
  2. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball/chainable.py +9 -7
  3. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball.egg-info/PKG-INFO +1 -1
  4. {heavyball-1.2.1 → heavyball-1.2.2}/setup.py +1 -1
  5. {heavyball-1.2.1 → heavyball-1.2.2}/LICENSE +0 -0
  6. {heavyball-1.2.1 → heavyball-1.2.2}/README.md +0 -0
  7. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball/__init__.py +0 -0
  8. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball/utils.py +0 -0
  9. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball.egg-info/SOURCES.txt +0 -0
  10. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball.egg-info/dependency_links.txt +0 -0
  11. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball.egg-info/requires.txt +0 -0
  12. {heavyball-1.2.1 → heavyball-1.2.2}/heavyball.egg-info/top_level.txt +0 -0
  13. {heavyball-1.2.1 → heavyball-1.2.2}/setup.cfg +0 -0
  14. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_bf16_params.py +0 -0
  15. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_bf16_q.py +0 -0
  16. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_bf16_storage.py +0 -0
  17. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_caution.py +0 -0
  18. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_channels_last.py +0 -0
  19. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_closure.py +0 -0
  20. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_ema.py +0 -0
  21. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_foreach.py +0 -0
  22. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_hook.py +0 -0
  23. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_mars.py +0 -0
  24. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_memory.py +0 -0
  25. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_merge.py +0 -0
  26. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_no_grad.py +0 -0
  27. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_psgd.py +0 -0
  28. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_soap.py +0 -0
  29. {heavyball-1.2.1 → heavyball-1.2.2}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.1
3
+ Version: 1.2.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -287,16 +287,19 @@ def heavyball_momentum(group, updates, grads, params, momentum):
287
287
  return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
288
288
 
289
289
 
290
+ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
291
+
292
+
290
293
  @zero_guard("exp_avg", "exp_avg_sq")
291
294
  @general_guard("Q", "GG", init_fn=_init_soap)
292
295
  @no_state
293
- def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG):
296
+ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
294
297
  update = utils.promote(update)
295
298
 
296
299
  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
297
- precond = utils.adam_(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group),
298
- utils.scalar_guard(group['step'], exp_avg[0]))
299
- precond = [utils.project(p, q, False) for p, q in zip(precond, Q)]
300
+ fn = _optim_fns[inner]
301
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
302
+ precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
300
303
 
301
304
  for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
302
305
  utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
@@ -355,8 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
355
358
  update = update.to(memory_format=torch.contiguous_format)
356
359
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
357
360
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
358
- out = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
359
- return torch.as_strided(out, old.shape, old.stride())
361
+ return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
360
362
 
361
363
 
362
364
  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@@ -512,7 +514,7 @@ class BaseOpt(ChainOpt):
512
514
 
513
515
  fns = tuple(fns)
514
516
 
515
- self.compile_step = default(compile_step, self.compile_step)
517
+ self.compile_step = default(compile_step, self.compile_step)
516
518
  if default(palm, self.palm):
517
519
  fns = (palm_beta2,) + fns
518
520
  if default(gradient_clipping, self.gradient_clipping) is not None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.2.1
3
+ Version: 1.2.2
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='1.2.1',
13
+ version='1.2.2',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes