heavyball 0.22.0__tar.gz → 0.23.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {heavyball-0.22.0 → heavyball-0.23.0}/PKG-INFO +2 -2
  2. {heavyball-0.22.0 → heavyball-0.23.0}/README.md +1 -1
  3. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/delayed_psgd.py +6 -6
  4. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/p_adam.py +2 -2
  5. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/psgd_kron.py +1 -1
  6. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/pure_psgd.py +1 -1
  7. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/utils.py +84 -85
  8. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/PKG-INFO +2 -2
  9. {heavyball-0.22.0 → heavyball-0.23.0}/setup.py +1 -1
  10. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_params.py +0 -8
  11. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_q.py +0 -8
  12. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_storage.py +0 -6
  13. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_caution.py +0 -9
  14. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_mars.py +3 -11
  15. {heavyball-0.22.0 → heavyball-0.23.0}/LICENSE +0 -0
  16. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/__init__.py +0 -0
  17. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/cached_delayed_psgd_kron.py +0 -0
  18. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/cached_psgd_kron.py +0 -0
  19. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/foreach_adamw.py +0 -0
  20. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/foreach_adopt.py +0 -0
  21. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/foreach_laprop.py +0 -0
  22. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/foreach_sfadamw.py +0 -0
  23. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/foreach_soap.py +0 -0
  24. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/palm_foreach_sfadamw.py +0 -0
  25. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/palm_foreach_soap.py +0 -0
  26. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/precond_schedule_foreach_soap.py +0 -0
  27. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
  28. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/precond_schedule_sfpsoap.py +0 -0
  29. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
  30. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/SOURCES.txt +0 -0
  31. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/dependency_links.txt +0 -0
  32. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/requires.txt +0 -0
  33. {heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/top_level.txt +0 -0
  34. {heavyball-0.22.0 → heavyball-0.23.0}/setup.cfg +0 -0
  35. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_closure.py +0 -0
  36. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_ema.py +0 -0
  37. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_foreach.py +0 -0
  38. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_memory.py +0 -0
  39. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_merge.py +0 -0
  40. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_no_grad.py +0 -0
  41. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_psgd.py +0 -0
  42. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_soap.py +0 -0
  43. {heavyball-0.22.0 → heavyball-0.23.0}/test/test_stochastic_updates.py +0 -0

{heavyball-0.22.0 → heavyball-0.23.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.22.0
+Version: 0.23.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features

{heavyball-0.22.0 → heavyball-0.23.0}/README.md
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
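
The README excerpt above names `PrecondSchedulePaLMSOAP` as the recommended stable optimizer. A minimal usage sketch, assuming the drop-in `torch.optim`-style interface the README describes (the toy model and hyperparameters are hypothetical):

    import torch
    from torch import nn
    import heavyball

    # hypothetical toy model; any torch.nn.Module works the same way
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()

    # heavyball optimizers follow the usual params/lr constructor and step()/zero_grad() loop
    opt = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

    for _ in range(10):
        loss = model(torch.randn(32, 256, device='cuda')).square().mean()
        loss.backward()
        opt.step()
        opt.zero_grad()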

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball/delayed_psgd.py
@@ -5,16 +5,16 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import stochastic_lerp_, beta_debias
+from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    triu_to_line, line_to_triu, promote
+    triu_to_line, line_to_triu, promote,_compilable_update_
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_deca, clip_fn, caution, grad):
-    new = psgd_precond_grad(q, exprs, ea)
-    update_param_([p], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
+def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
+    new = psgd_precond_grad(False, exprs, ea, *q)
+    _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -114,7 +114,7 @@ class ForeachDelayedPSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-            _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+            _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"][-1], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
                                            g)
             if should_update:
                 q32 = [promote(q_) for q_ in q]

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball/p_adam.py
@@ -110,8 +110,8 @@ class ForeachPaLMPAdam(PSGDBase):
 
         for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
             gc = g.clone() if group['caution'] else None
-            psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
-            ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *Q)
+            ea = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *Q)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
             torch.div(ea, g, out=g)
             """

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball/psgd_kron.py
@@ -116,5 +116,5 @@ class ForeachPSGDKron(PSGDBase):
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+            g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball/pure_psgd.py
@@ -101,5 +101,5 @@ class ForeachPurePSGD(PSGDBase):
             if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-            psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *q)
             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball/utils.py
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple, Callable, Union
 
 import numpy as np
 import torch
+from torch import Tensor
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
 
@@ -39,15 +40,14 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
-    p32 = promote(p)
-    z32 = promote(z)
-    p32.lerp_(end=z32, weight=ckp1)
-    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
-    copy_stochastic_(p, p32)
-
-    z32.add_(grad, alpha=-lr)
-    copy_stochastic_(z, z32)
+def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor, beta1: Tensor):
+    p32, z32, g32 = [promote(x) for x in (p, z, grad)]
+    for p_, z_, g_ in zip(p32, z32, g32):
+        p_.lerp_(z_, ckp1)
+        p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
+        z_.add(g_, alpha=-lr)
+    copy_stochastic_list_(p, p32)
+    copy_stochastic_list_(z, z32)
 
 
 def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
@@ -61,8 +61,8 @@ def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
     return ckp1, weight_sum
 
 
-def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
-                   z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
+def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
+                   z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
@@ -73,10 +73,8 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
    # without computing x explicitly.
-    lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
-    ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
-    for p, z_, g in zip(parameters, z, grad):
-        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
+    lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
+    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
     return weight_sum
 
 
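Read directly from the added lines of `_compilable_schedule_free_` and `schedule_free_` above, the per-parameter step with c = ckp1 is: p <- p + c * (z - p), then p <- p + lr * (beta1 * (1 - c) - 1) * g, and z <- z - lr * g. A plain-tensor restatement with toy values and no stochastic rounding:

    import torch

    # toy single-parameter restatement of the schedule-free step above
    p, z = torch.zeros(3), torch.ones(3)
    g = torch.full((3,), 0.1)
    lr, beta1, c = 0.01, 0.9, 0.25   # c plays the role of ckp1

    p.lerp_(z, c)                                # p <- p + c * (z - p)
    p.add_(g, alpha=lr * (beta1 * (1 - c) - 1))  # p <- p + lr * (beta1*(1-c) - 1) * g
    z.add_(g, alpha=-lr)                         # z <- z - lr * g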
@@ -142,27 +140,25 @@ def beta_debias(beta, step):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_exp_avg_sq_(state, grad, beta2, eps, out=None):
+def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]):
     torch._foreach_mul_(state, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
     denom = torch._foreach_sqrt(state)
     [denom.clamp_(min=eps) for denom in denom]
-    if out is not None:
-        copy_stochastic_list_(out, denom)
-        return out
+    if out[0] is None:
+        return denom
 
-    return denom
+    copy_stochastic_list_(out, denom)
+    return out
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad = list_guard(state), list_guard(grad)
-    if not isinstance(beta2, torch.Tensor):
-        beta2 = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(beta2)
-    if not isinstance(eps, torch.Tensor):
-        eps = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(eps)
+    state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
+    beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
-def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
+
+def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return
@@ -183,7 +179,7 @@ def is_compiling():
     return True
 
 
-def set_(dst: torch.Tensor, src: torch.Tensor):
+def set_(dst: Tensor, src: Tensor):
     if not is_compiling() and src.data_ptr() == dst.data_ptr():
         return
     if src.shape != dst.shape:
@@ -344,7 +340,7 @@ def get_orthogonal_matrix(mat):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
@@ -352,10 +348,9 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a
         copy_stochastic_(x_, x32)
 
 
-def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     x, y = list_guard(x), list_guard(y)
-    if not isinstance(a, torch.Tensor):
-        a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+    a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
@@ -365,8 +360,16 @@ def list_guard(x):
     return [x]
 
 
+def scalar_guard(x, ref):
+    if isinstance(x, float):
+        return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
+    if isinstance(x, int):
+        return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
+    return x
+
+
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
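`scalar_guard` (added above) pairs with the existing `list_guard` and is used throughout the rest of this diff to replace the per-call `torch.empty(()).fill_(...)` boilerplate: Python scalars are wrapped once into 0-dim tensors on the reference parameter's device before entering the `torch.compile(..., dynamic=False)` kernels, presumably so that a changing scalar value does not force recompilation. A standalone illustration (not heavyball code; the compiled toy function is hypothetical):

    import torch

    ref = torch.zeros(4)

    # what scalar_guard(1e-3, ref) and scalar_guard(10, ref) would return
    lr = torch.empty((), dtype=torch.float32, device=ref.device).fill_(1e-3)
    step = torch.empty((), dtype=torch.int64, device=ref.device).fill_(10)

    @torch.compile(fullgraph=True, dynamic=False)
    def scaled_add(x, lr):
        # tensor-valued lr: its value can change between calls without retracing
        return x + lr * x

    out = scaled_add(ref + 1.0, lr)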
@@ -374,10 +377,9 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al
         copy_stochastic_(x_, x32)
 
 
-def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     x, y = list_guard(x), list_guard(y)
-    if not isinstance(alpha, torch.Tensor):
-        alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+    alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
 
@@ -399,12 +401,12 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 def promote(x):
     if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if isinstance(x, torch.Tensor) and x.dtype in (torch.bfloat16, torch.float16):
+    if isinstance(x, Tensor) and x.dtype in (torch.bfloat16, torch.float16):
         return x.float()
     return x
 
 
-def min_dtype(xs: List[torch.Tensor]):
+def min_dtype(xs: List[Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
         if all(x in (d, torch.float32, torch.float64) for x in dtypes):
@@ -470,7 +472,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         self.fake_groups = {}
         self.use_ema = use_ema
 
-    def key(self, param: torch.Tensor):
+    def key(self, param: Tensor):
         return (param.data_ptr(), tuple(param.shape))
 
     def get_groups(self, group):
@@ -483,7 +485,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         return [self.fake_groups[self.key(p)] for p in group['params']]
 
-    def state_(self, arg: torch.Tensor):
+    def state_(self, arg: Tensor):
         return self.state[self.key(arg)]
 
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
@@ -515,7 +517,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
-            if isinstance(p_views, torch.Tensor):
+            if isinstance(p_views, Tensor):
                 yield p_views, grad
                 continue
             if grad is None:
@@ -528,7 +530,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         def _add(x):
             nonlocal total_bytes
-            if isinstance(x, torch.Tensor):
+            if isinstance(x, Tensor):
                 total_bytes += x.numel() * x.element_size()
 
         for group in self.param_groups:
@@ -636,13 +638,14 @@ class ScheduleFree(StatefulOptimizer):
         raise NotImplementedError
 
 
-def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
+def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
     for t, s in zip(target, source):
         copy_stochastic_(t, s)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                         grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -655,21 +658,17 @@ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2
     return denom
 
 
-def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
-             grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
-    if isinstance(beta1, float):
-        beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
-    if isinstance(beta2, float):
-        beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
-    if isinstance(step, int):
-        step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
+             beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
+        grad), list_guard(grad_projected)
+    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
     denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
     return denom
 
 
-# this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
-@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
-def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -684,7 +683,7 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     target.copy_(result.view(dtype=torch.float32))
 
 
-def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+def copy_stochastic_(target: Tensor, source: Tensor):
     if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
     if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
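`copy_stochastic_` above rounds float32 values to bfloat16 stochastically rather than to-nearest; only fragments of the bit manipulation appear in this diff (the linked PyTorch issue has the full routine). A standalone sketch of the same idea, assuming the usual add-noise-below-the-cutoff-then-truncate formulation:

    import torch

    def to_bf16_stochastic(src: torch.Tensor) -> torch.Tensor:
        # reinterpret the float32 bits, add uniform noise in the low 16 bits,
        # then zero those bits so the float32 -> bfloat16 cast is exact
        bits = src.view(torch.int32)
        noise = torch.randint_like(bits, low=0, high=1 << 16)
        rounded = (bits + noise) & -65536
        return rounded.view(torch.float32).to(torch.bfloat16)

    x = torch.full((100_000,), 0.1)
    print(to_bf16_stochastic(x).float().mean())  # unbiased: close to 0.1 on average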
@@ -693,31 +692,31 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_update_(p, u, decay, add_fn, lr, caution, g):
+def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+                        g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
-    p32, u32, g32 = [list(map(promote, x)) for x in [p, u, g]]
+    p32, u32 = [list(map(promote, x)) for x in [p, u]]
 
     if decay > 0:
         torch._foreach_mul_(p32, 1 - decay * lr)
 
-    for p32_, u32_, g32_ in zip(p32, u32, g32):  # lr is data-dependent -> can't compile a foreach
+    for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
         if caution:
-            _compilable_cautioning_(g32_, u32_)
-        if add_fn is None:
-            p32_.add_(u32_, alpha=lr)
-        else:
-            add_fn(p32_, u32_, lr)
+            _compilable_cautioning_(promote(g_), u32_)
+        add_fn(p32_, u32_, lr)
 
     copy_stochastic_list_(p, p32)
 
 
-def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
-                  add_fn: callable = None, caution: bool = False, grad: List[torch.Tensor] = None):
-    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
+                  caution: bool = False, grad: List[Tensor] = None):
     param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
-    _compilable_update_(param, update, decay, add_fn, lr_tensor, caution, grad)
+    if add_fn is None:
+        add_fn = stochastic_add_
+    _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
 
 
 def precond_schedule(step, precond_scheduler, rng):
@@ -887,14 +886,14 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
+def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
     """Precondition gradient G with preconditioner Q."""
-    md = min_dtype(Q)
-    out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
+    md = min_dtype(preconds)
+    out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
     if inplace:
-        set_(G, out)
-        return G
-    return out.to(G.dtype)
+        set_(grad, out)
+        return grad
+    return out.to(grad.dtype)
 
 
 def norm_clip_(x, scale=None):
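The signature change above drives the caller updates in `delayed_psgd.py`, `p_adam.py`, `psgd_kron.py` and `pure_psgd.py` earlier in this diff: the einsum expression, previously selected internally via `exprs[-1]`, is now passed in explicitly, and the preconditioner factors arrive unpacked as `*preconds`. A self-contained imitation of the new calling convention (not the heavyball function itself; the factor shapes and einsum string are illustrative):

    import torch

    def precond_grad_sketch(inplace: bool, exprs: str, grad: torch.Tensor, *preconds: torch.Tensor):
        # einsum over the conjugated factors, the factors again, and the gradient,
        # mirroring the operand layout of psgd_precond_grad above
        out = torch.einsum(exprs, *[q.conj() for q in preconds], *preconds, grad)
        if inplace:
            grad.copy_(out)
            return grad
        return out.to(grad.dtype)

    q1, q2 = torch.eye(3), torch.eye(4)   # stand-ins for the Kronecker factors in Q
    g = torch.randn(3, 4)
    # analogous to: psgd_precond_grad(False, self.state_(p)["exprs"][-1], g, *Q)
    out = precond_grad_sketch(False, 'ab,cd,ae,cf,ef->bd', g, q1, q2)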
@@ -957,7 +956,7 @@ def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 
 @decorator
-def triu_to_line(Q_list: List[torch.Tensor]):
+def triu_to_line(Q_list: List[Tensor]):
     out = []
     for q in Q_list:
         if q.dim() < 2:
@@ -974,7 +973,7 @@ def _triu_shape(numel):
 
 
 @decorator
-def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
+def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
     new = []
     for shape, q in Q_list:
         if shape is not None:
@@ -1031,22 +1030,22 @@ class PSGDBase(StatefulOptimizer):
 
 
 # TODO: Figure out why this sometimes crashes
-# @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
+                                     clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
     update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
 
 
-def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
-                         weight_decay: float, clip_fn, caution, grad):
-    if isinstance(lr, float):
-        lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
-    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad)
+def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
+                         clip_fn, caution, grad):
+    lr = scalar_guard(lr, param)
+    _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_mars_correction_(g, old_g, a):
+def _compilable_mars_correction_(g: Tensor, old_g: Tensor, a: Tensor):
     g_copy = [g_.clone() for g_ in g]
     _compilable_stochastic_lerp_(g, old_g, a)
     copy_stochastic_list_(old_g, g_copy)
@@ -1055,12 +1054,12 @@ def _compilable_mars_correction_(g, old_g, a):
 def mars_correction(g, old_g, beta1, gamma):
     a = -gamma * beta1 / (1 - beta1)
     g, old_g = list_guard(g), list_guard(old_g)
-    a = torch.empty((), dtype=torch.float32, device=g[0].device).fill_(a)
+    a = scalar_guard(a, g[0])
     _compilable_mars_correction_(g, old_g, a)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_cautioning_(g, update):
+def _compilable_cautioning_(g: Tensor, update: Tensor):
     mask = (g * update) > 0
     update.masked_fill_(~mask, 0)
     scale = mask.numel() / mask.sum().clamp(min=1)
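For reference, the cautioning kernel at the end of the hunk above keeps only the update entries whose sign agrees with the gradient and computes a rescaling factor from the surviving fraction; the final multiplication by `scale` is not visible in the hunk and is assumed here. A toy version:

    import torch

    def cautious_update_sketch(g: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        mask = (g * update) > 0                         # keep sign-consistent entries
        update = update.masked_fill(~mask, 0)
        scale = mask.numel() / mask.sum().clamp(min=1)  # clamp avoids division by zero
        return update * scale                           # assumed final step

    g = torch.tensor([1.0, -2.0, 0.5, -1.0])
    u = torch.tensor([0.3, 0.4, -0.2, -0.1])
    print(cautious_update_sketch(g, u))  # middle two entries zeroed, the rest scaled by 2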

{heavyball-0.22.0 → heavyball-0.23.0}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.22.0
+Version: 0.23.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features

{heavyball-0.22.0 → heavyball-0.23.0}/setup.py
@@ -10,7 +10,7 @@ setuptools.setup(
     name='heavyball',
     license='BSD',
     description='Efficient optimizers',
-    version='0.22.0',
+    version='0.23.0',
     long_description=README,
     url='https://github.com/clashluke/heavyball',
     packages=setuptools.find_packages(),

{heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_params.py
@@ -10,14 +10,6 @@ import torch._inductor.config as ind_cfg
 
 config.cache_size_limit = 128
 
-
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
 @pytest.mark.parametrize("opt", ['CachedDelayedPSGDKron'])
 @pytest.mark.parametrize("size,depth", [(256, 1)])
 def test_foreach(opt, size, depth: int, iterations: int = 16, outer_iterations: int = 3):

{heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_q.py
@@ -11,14 +11,6 @@ from heavyball.utils import clean, set_torch, PSGDBase
 config.cache_size_limit = 128
 
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(256, 2)])
 def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):

{heavyball-0.22.0 → heavyball-0.23.0}/test/test_bf16_storage.py
@@ -10,12 +10,6 @@ from heavyball.utils import clean, set_torch, PSGDBase
 
 config.cache_size_limit = 128
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)

{heavyball-0.22.0 → heavyball-0.23.0}/test/test_caution.py
@@ -9,15 +9,6 @@ from torch._dynamo import config
 
 config.cache_size_limit = 128
 
-
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 2)])
 def test_caution(opt, size, depth: int, iterations: int = 65536, outer_iterations: int = 2):

{heavyball-0.22.0 → heavyball-0.23.0}/test/test_mars.py
@@ -10,17 +10,9 @@ from torch._dynamo import config
 config.cache_size_limit = 128
 
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 2)])
-def test_mars(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 2):
+def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: int = 2):
     set_torch()
     opt = getattr(heavyball, opt)
     if ScheduleFree in opt.__mro__:
@@ -35,11 +27,11 @@ def test_mars(opt, size, depth: int, iterations: int = 1024, outer_iterations: i
         losses.append([])
 
     for i in range(outer_iterations):
-        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().double()
        o = get_optim(opt, model.parameters(), lr=1e-5, mars=mars)
 
        for _ in range(iterations):
-            loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+            loss = model(torch.randn((1024, size), device='cuda', dtype=torch.double)).square().mean()
            loss.backward()
            o.step()
            o.zero_grad()