heavyball 1.1.0__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.1.0 → heavyball-1.1.1}/PKG-INFO +1 -1
  2. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/chainable.py +8 -7
  3. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/utils.py +19 -4
  4. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/PKG-INFO +1 -1
  5. {heavyball-1.1.0 → heavyball-1.1.1}/setup.py +1 -1
  6. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_channels_last.py +2 -2
  7. {heavyball-1.1.0 → heavyball-1.1.1}/LICENSE +0 -0
  8. {heavyball-1.1.0 → heavyball-1.1.1}/README.md +0 -0
  9. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball/__init__.py +0 -0
  10. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/SOURCES.txt +0 -0
  11. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/dependency_links.txt +0 -0
  12. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/requires.txt +0 -0
  13. {heavyball-1.1.0 → heavyball-1.1.1}/heavyball.egg-info/top_level.txt +0 -0
  14. {heavyball-1.1.0 → heavyball-1.1.1}/setup.cfg +0 -0
  15. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_params.py +0 -0
  16. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_q.py +0 -0
  17. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_bf16_storage.py +0 -0
  18. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_caution.py +0 -0
  19. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_closure.py +0 -0
  20. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_ema.py +0 -0
  21. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_foreach.py +0 -0
  22. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_hook.py +0 -0
  23. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_mars.py +0 -0
  24. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_memory.py +0 -0
  25. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_merge.py +0 -0
  26. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_no_grad.py +0 -0
  27. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_psgd.py +0 -0
  28. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_soap.py +0 -0
  29. {heavyball-1.1.0 → heavyball-1.1.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -140,16 +140,14 @@ class SkipUpdate(ValueError):
140
140
  @zero_guard("exp_avg")
141
141
  @no_state
142
142
  def exp_avg(group, update, grad, param, exp_avg):
143
- utils.stochastic_lerp_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
144
- return exp_avg
143
+ return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
145
144
 
146
145
 
147
146
  @zero_guard("exp_avg_sq")
148
147
  @no_state
149
148
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
150
- out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
151
- group['eps'])
152
- return out
149
+ return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
150
+ group['eps'])
153
151
 
154
152
 
155
153
  @zero_guard("exp_avg", "exp_avg_sq")
@@ -350,9 +348,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
350
348
  @no_state_no_foreach
351
349
  def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
352
350
  prob: Optional[callable] = None):
351
+ old = update
352
+ update = update.to(memory_format=torch.contiguous_format)
353
353
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
354
354
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
355
- return _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
355
+ out = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
356
+ return torch.as_strided(out, old.shape, old.stride())
356
357
 
357
358
 
358
359
  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@@ -400,7 +401,7 @@ def apply_to_idx(fn, idx):
400
401
 
401
402
 
402
403
  def chain(state: Union[callable, dict], group, grad, param, *fns):
403
- update = [torch.clone(g) for g in grad]
404
+ update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
404
405
  skip_update = False
405
406
  for fn in fns:
406
407
  try:
@@ -201,6 +201,21 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
201
201
  return grad
202
202
 
203
203
 
204
+ @decorator_knowngood
205
+ def _compilable_exp_avg_(state, grad, beta):
206
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
207
+ [s.lerp_(g, beta) for s, g in zip(s32, g32)]
208
+ copy_stochastic_list_(state, s32)
209
+ copy_stochastic_list_(grad, s32)
210
+
211
+
212
+ def scale_by_exp_avg_(state, grad, beta):
213
+ state, grad = list_guard(state, grad)
214
+ beta = scalar_guard(beta, state[0])
215
+ _compilable_exp_avg_(state, grad, beta)
216
+ return grad
217
+
218
+
204
219
  @decorator_knowngood
205
220
  def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
206
221
  p_norm = torch._foreach_norm(parameters)
@@ -321,13 +336,13 @@ def nesterov_momentum(state, grad, beta):
321
336
 
322
337
  @decorator_knowngood
323
338
  def inplace_orthogonal_(x, mode, out):
324
- if mode == 'qr':
339
+ if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
340
+ y = zeropower_via_newtonschulz5(x, 5)
341
+ elif mode == 'qr':
325
342
  y = torch.linalg.qr(x).Q
326
343
  elif mode == 'svd':
327
344
  u, s, v = torch.linalg.svd(x)
328
345
  y = u @ v.T
329
- elif mode == 'newtonschulz':
330
- y = zeropower_via_newtonschulz5(x, 5)
331
346
  else:
332
347
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
333
348
  set_(out, y)
@@ -363,7 +378,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
363
378
  est_eig = torch.einsum('ij,ij->j', o, tmp)
364
379
  sort_idx = torch.argsort(est_eig, descending=True)
365
380
  indices.append(sort_idx)
366
- inplace_orthogonal_(tmp[:, sort_idx], q)
381
+ inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
367
382
 
368
383
  indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
369
384
  for i, ind in enumerate(indices))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
10
10
  name='heavyball',
11
11
  license='BSD',
12
12
  description='Efficient optimizers',
13
- version='1.1.0',
13
+ version='1.1.1',
14
14
  long_description=README,
15
15
  url='https://github.com/clashluke/heavyball',
16
16
  packages=setuptools.find_packages(),
@@ -17,7 +17,7 @@ config.cache_size_limit = 128
17
17
 
18
18
  @pytest.mark.parametrize("opt", heavyball.__all__)
19
19
  @pytest.mark.parametrize("size,depth", [(128, 1)])
20
- def test_foreach(opt, size, depth: int, iterations: int = 32, outer_iterations: int = 1):
20
+ def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 1):
21
21
  set_torch()
22
22
  opt = getattr(heavyball, opt)
23
23
 
@@ -34,7 +34,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 32, outer_iterations:
34
34
  if is_channels_last:
35
35
  model.to(memory_format=torch.channels_last)
36
36
 
37
- o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16)
37
+ o = get_optim(opt, model.parameters(), lr=1e-5, weight_decay=1e-4, warmup_steps=16)
38
38
 
39
39
  for _ in range(iterations):
40
40
  loss = model(torch.randn((1024, size, 4, 4), device='cuda')).square().mean()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes