heavyball 0.21.8-py3-none-any.whl → 0.23.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry, and is provided for informational purposes only.
heavyball/utils.py CHANGED
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple, Callable, Union
 
  import numpy as np
  import torch
+ from torch import Tensor
  from torch.backends import cudnn, opt_einsum
  from torch.utils._pytree import tree_map
 
@@ -39,15 +40,14 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
-     p32 = promote(p)
-     z32 = promote(z)
-     p32.lerp_(end=z32, weight=ckp1)
-     p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
-     copy_stochastic_(p, p32)
-
-     z32.add_(grad, alpha=-lr)
-     copy_stochastic_(z, z32)
+ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor, beta1: Tensor):
+     p32, z32, g32 = [promote(x) for x in (p, z, grad)]
+     for p_, z_, g_ in zip(p32, z32, g32):
+         p_.lerp_(z_, ckp1)
+         p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
+         z_.add(g_, alpha=-lr)
+     copy_stochastic_list_(p, p32)
+     copy_stochastic_list_(z, z32)
 
 
  def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
@@ -61,8 +61,8 @@ get_ckp1(lr, weight_lr_power, weight_sum, r, step):
      return ckp1, weight_sum
 
 
- def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
-                    z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
+ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
+                    z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
      weight = lr ** weight_lr_power * max(step, 1) ** r
      weight_sum = weight_sum + weight
 
@@ -73,10 +73,8 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
      # These operations update y in-place,
      # without computing x explicitly.
-     lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
-     ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
-     for p, z_, g in zip(parameters, z, grad):
-         _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
+     lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
+     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
      return weight_sum
 
 
@@ -142,19 +140,25 @@ def beta_debias(beta, step):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def exp_avg_sq_(state, grad, beta2, eps, out=None):
-     if isinstance(state, torch.Tensor):
-         state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-         return torch.sqrt(state, out=out).clamp_(min=eps)
-
+ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]):
      torch._foreach_mul_(state, beta2)
      [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
      denom = torch._foreach_sqrt(state)
-     torch._foreach_maximum_(denom, eps)
-     return denom
+     [denom.clamp_(min=eps) for denom in denom]
+     if out[0] is None:
+         return denom
 
+     copy_stochastic_list_(out, denom)
+     return out
 
- def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
+
+ def exp_avg_sq_(state, grad, beta2, eps, out=None):
+     state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
+     beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
+     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
+
+
+ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
                                  minimum: float = 1e-3, eps: float = 1e-8):
      if clip_val <= 0:
          return
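
`exp_avg_sq_` is now a thin wrapper that normalizes its inputs with `list_guard`/`scalar_guard` and dispatches to the compiled `_compilable_exp_avg_sq_`; with `out=None` it returns the freshly computed denominators as a list. A minimal sketch of the call shape, assuming heavyball 0.23.0 is installed (the tensors are made up):

    import torch
    from heavyball.utils import exp_avg_sq_

    v = torch.zeros(4)   # second-moment state, updated in place
    g = torch.randn(4)   # gradient
    # The bare tensors are wrapped into one-element lists and beta2/eps become
    # 0-dim device tensors; since out is None, this returns [sqrt(v_new).clamp_(min=eps)].
    denom = exp_avg_sq_(v, g, 0.99, 1e-8)
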
@@ -168,12 +172,19 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
      torch._foreach_mul_(gradients, p_norm)
 
 
- def set_(dst: torch.Tensor, src: torch.Tensor):
-     if not torch.compiler.is_compiling() and src.data_ptr() == dst.data_ptr():
+ def is_compiling():
+     try:
+         return torch.compiler.is_compiling()
+     except AttributeError:
+         return True
+
+
+ def set_(dst: Tensor, src: Tensor):
+     if not is_compiling() and src.data_ptr() == dst.data_ptr():
          return
      if src.shape != dst.shape:
          src = src.reshape_as(dst)
-     if not torch.compiler.is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
+     if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
          dst.set_(src)
      else:
          dst.copy_(src)
@@ -329,7 +340,7 @@ def get_orthogonal_matrix(mat):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
      for x_, y_ in zip(x, y):
          x32 = promote(x_)
          y32 = promote(y_)
@@ -337,14 +348,28 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a
          copy_stochastic_(x_, x32)
 
 
- def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
-     if not isinstance(a, torch.Tensor):
-         a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+ def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
+     x, y = list_guard(x), list_guard(y)
+     a = scalar_guard(a, x[0])
      _compilable_stochastic_lerp_(x, y, a)
 
 
+ def list_guard(x):
+     if isinstance(x, (list, tuple)):
+         return x
+     return [x]
+
+
+ def scalar_guard(x, ref):
+     if isinstance(x, float):
+         return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
+     if isinstance(x, int):
+         return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
+     return x
+
+
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
      for x_, y_ in zip(x, y):
          x32 = promote(x_)
          y32 = promote(y_)
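
The new `list_guard`/`scalar_guard` helpers replace the per-call `torch.empty(()).fill_()` conversions: public wrappers such as `stochastic_lerp_` above (and `stochastic_add_` just below) now accept bare tensors and plain Python scalars, which are normalized to lists and 0-dim device tensors before entering the `torch.compile`d kernels. A minimal sketch, assuming heavyball 0.23.0 is installed (the tensors are made up):

    import torch
    from heavyball.utils import stochastic_add_

    x = torch.randn(8, dtype=torch.bfloat16)
    y = torch.randn(8, dtype=torch.bfloat16)
    # list_guard wraps the bare tensors into one-element lists; scalar_guard
    # turns 0.5 into a 0-dim float32 tensor on x's device.
    stochastic_add_(x, y, 0.5)  # x += 0.5 * y with stochastic bf16 rounding
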
@@ -352,9 +377,9 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al
          copy_stochastic_(x_, x32)
 
 
- def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
-     if not isinstance(alpha, torch.Tensor):
-         alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
+     x, y = list_guard(x), list_guard(y)
+     alpha = scalar_guard(alpha, x[0])
      _compilable_stochastic_add_(x, y, alpha)
 
 
@@ -376,12 +401,12 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  def promote(x):
      if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
          return torch.float32
-     if isinstance(x, torch.Tensor) and x.dtype in (torch.bfloat16, torch.float16):
+     if isinstance(x, Tensor) and x.dtype in (torch.bfloat16, torch.float16):
          return x.float()
      return x
 
 
- def min_dtype(xs: List[torch.Tensor]):
+ def min_dtype(xs: List[Tensor]):
      dtypes = [x.dtype for x in xs]
      for d in (torch.float32, torch.bfloat16, torch.float16):
          if all(x in (d, torch.float32, torch.float64) for x in dtypes):
@@ -447,7 +472,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
          self.fake_groups = {}
          self.use_ema = use_ema
 
-     def key(self, param: torch.Tensor):
+     def key(self, param: Tensor):
          return (param.data_ptr(), tuple(param.shape))
 
      def get_groups(self, group):
@@ -460,19 +485,56 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
          return [self.fake_groups[self.key(p)] for p in group['params']]
 
-     def state_(self, arg: torch.Tensor):
+     def state_(self, arg: Tensor):
          return self.state[self.key(arg)]
 
+     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
+         for p, g in zip(p_list, g_list):
+             state = self.state_(p)
+             if 'mars_old_grad' not in state:
+                 state['mars_old_grad'] = torch.zeros_like(g)
+         old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
+         mars_correction(g_list, old_gs, mars_gamma, beta)
+
+     def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
+                                beta1: float = -1.0):
+         for p in group["params"]:
+             if skip_none and p.grad is None:
+                 continue
+
+             if p.grad is None:
+                 grad = None
+             else:
+                 if should_promote:
+                     grad = promote(p.grad)
+                 else:
+                     grad = p.grad
+                 if beta1 >= 0 and group.get('mars', False):
+                     self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+
+             p.grad = None
+
+             p_views = merge_group(group, p)
+             if grad is not None:
+                 grad = merge_group(group, grad)
+             if isinstance(p_views, Tensor):
+                 yield p_views, grad
+                 continue
+             if grad is None:
+                 yield from zip(p_views, [None] * len(p_views))
+                 continue
+             yield from zip(p_views, grad)
+
      def state_size(self) -> int:
          total_bytes = 0
 
          def _add(x):
              nonlocal total_bytes
-             if isinstance(x, torch.Tensor):
+             if isinstance(x, Tensor):
                  total_bytes += x.numel() * x.element_size()
 
          for group in self.param_groups:
-             for p, _ in split_p_and_g_in_group(group, skip_none=False):
+             for p, _ in self.split_p_and_g_in_group(group, skip_none=False):
                  tree_map(_add, self.state_(p))
          return total_bytes
 
@@ -576,13 +638,14 @@ class ScheduleFree(StatefulOptimizer):
          raise NotImplementedError
 
 
- def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
+ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
      for t, s in zip(target, source):
          copy_stochastic_(t, s)
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+ def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                          grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
      beta1 = beta_debias(beta1, step)
      beta2 = beta_debias(beta2, step)
 
@@ -595,21 +658,17 @@ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2
      return denom
 
 
- def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
-              grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
-     if isinstance(beta1, float):
-         beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
-     if isinstance(beta2, float):
-         beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
-     if isinstance(step, int):
-         step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+ def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
+              beta1: float, beta2: float, step: int):
+     exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
+         grad), list_guard(grad_projected)
+     beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
      denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
      return denom
 
 
- # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
- @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
- def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
      """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
      # create a random 16 bit integer
      result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -624,8 +683,8 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
      target.copy_(result.view(dtype=torch.float32))
 
 
- def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
-     if not torch.compiler.is_compiling() and target.data_ptr() == source.data_ptr():
+ def copy_stochastic_(target: Tensor, source: Tensor):
+     if not is_compiling() and target.data_ptr() == source.data_ptr():
          return
      if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
          set_(target, source)
@@ -633,26 +692,31 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_update_(p, u, decay, add_fn, lr):
+ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+                         g: List[Optional[Tensor]]):
      u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
      p32, u32 = [list(map(promote, x)) for x in [p, u]]
 
      if decay > 0:
          torch._foreach_mul_(p32, 1 - decay * lr)
 
-     for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
-         if add_fn is None:
-             p32_.add_(u32_, alpha=lr)
-         else:
-             add_fn(p32_, u32_, lr)
+     for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
+         if caution:
+             _compilable_cautioning_(promote(g_), u32_)
+         add_fn(p32_, u32_, lr)
 
      copy_stochastic_list_(p, p32)
 
 
- def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
-                   add_fn: callable = None):
-     lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
-     _compilable_update_(param, update, decay, add_fn, lr_tensor)
+ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
+                   caution: bool = False, grad: List[Tensor] = None):
+     param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+     lr = scalar_guard(lr, param[0])
+     if not caution:
+         grad = [None] * len(param)
+     if add_fn is None:
+         add_fn = stochastic_add_
+     _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
 
 
  def precond_schedule(step, precond_scheduler, rng):
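
`update_param_` now defaults `add_fn` to `stochastic_add_` and grows `caution`/`grad` arguments: with `caution=True`, each update tensor is first filtered by `_compilable_cautioning_` (added further down in this diff) against the raw gradient. A minimal sketch of the new call shape, assuming heavyball 0.23.0 is installed (the tensors and hyperparameters are made up):

    import torch
    from heavyball.utils import update_param_

    p = [torch.randn(8)]   # parameters
    u = [torch.randn(8)]   # update direction (already carries its sign)
    g = [torch.randn(8)]   # raw gradient, only consulted when caution=True
    # Entries of u whose sign disagrees with g are zeroed and the survivors
    # rescaled, then p += lr * u; a positive decay would additionally apply
    # p *= 1 - decay * lr before the update.
    update_param_(p, u, lr=1e-3, decay=0.0, caution=True, grad=g)
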
@@ -822,14 +886,14 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
+ def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
      """Precondition gradient G with preconditioner Q."""
-     md = min_dtype(Q)
-     out = torch.einsum(exprs[-1], *[q.conj().to(md) for q in Q], *[q.to(md) for q in Q], G.to(md))
+     md = min_dtype(preconds)
+     out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
      if inplace:
-         set_(G, out)
-         return G
-     return out.to(G.dtype)
+         set_(grad, out)
+         return grad
+     return out.to(grad.dtype)
 
 
  def norm_clip_(x, scale=None):
@@ -892,7 +956,7 @@ def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 
  @decorator
- def triu_to_line(Q_list: List[torch.Tensor]):
+ def triu_to_line(Q_list: List[Tensor]):
      out = []
      for q in Q_list:
          if q.dim() < 2:
@@ -909,7 +973,7 @@ def _triu_shape(numel):
 
 
  @decorator
- def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
+ def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
      new = []
      for shape, q in Q_list:
          if shape is not None:
@@ -965,18 +1029,45 @@ class PSGDBase(StatefulOptimizer):
              psgd_balance_Q(q)
 
 
+ # TODO: Figure out why this sometimes crashes
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
- def _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn):
+ def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
+                                      clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
      md = min_dtype(cached_q + [ea])
      new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
-     update_param_([param], clip_fn([new]), lr, weight_decay)
+     update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
+
+
+ def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
+                          clip_fn, caution, grad):
+     lr = scalar_guard(lr, param)
+     _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_mars_correction_(g: Tensor, old_g: Tensor, a: Tensor):
+     g_copy = [g_.clone() for g_ in g]
+     _compilable_stochastic_lerp_(g, old_g, a)
+     copy_stochastic_list_(old_g, g_copy)
+
+
+ def mars_correction(g, old_g, beta1, gamma):
+     a = -gamma * beta1 / (1 - beta1)
+     g, old_g = list_guard(g), list_guard(old_g)
+     a = scalar_guard(a, g[0])
+     _compilable_mars_correction_(g, old_g, a)
 
 
- def precond_grad_cached_(cached_q: List[torch.Tensor], ea: torch.Tensor, expr: str, param: torch.Tensor, lr: float,
-                          weight_decay: float, clip_fn):
-     if isinstance(lr, float):
-         lr = torch.empty((), dtype=torch.float32, device=param.device).fill_(lr)
-     _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn)
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def _compilable_cautioning_(g: Tensor, update: Tensor):
+     mask = (g * update) > 0
+     update.masked_fill_(~mask, 0)
+     scale = mask.numel() / mask.sum().clamp(min=1)
+     update.mul_(scale)
+
+
+ def caution(g, update):
+     _compilable_cautioning_(g, update)
 
 
  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
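
`_compilable_cautioning_` is the masking rule behind the new `caution` path: update entries whose sign disagrees with the gradient are zeroed and the survivors are rescaled by `numel / kept`, while `mars_correction` blends the incoming gradient with the stored previous one via `_compilable_stochastic_lerp_` and keeps the uncorrected gradient for the next step. A tiny numeric check of the masking rule, assuming heavyball 0.23.0 is installed (the values are made up):

    import torch
    from heavyball.utils import caution

    g = torch.tensor([1.0, -1.0, 2.0])   # gradient
    u = torch.tensor([0.5,  0.5, 0.5])   # proposed update
    caution(g, u)    # mask = g * u > 0 keeps indices 0 and 2; scale = 3 / 2
    print(u)         # tensor([0.7500, 0.0000, 0.7500])
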
@@ -1013,29 +1104,3 @@ def merge_group(group, *tensors):
          append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
              'max_precond_dim'], group.get('split', False)))
      return out
-
-
- def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
-     for p in group["params"]:
-         if skip_none and p.grad is None:
-             continue
-
-         if p.grad is None:
-             grad = None
-         else:
-             if should_promote:
-                 grad = promote(p.grad)
-             else:
-                 grad = p.grad
-         p.grad = None
-
-         p_views = merge_group(group, p)
-         if grad is not None:
-             grad = merge_group(group, grad)
-         if isinstance(p_views, torch.Tensor):
-             yield p_views, grad
-             continue
-         if grad is None:
-             yield from zip(p_views, [None] * len(p_views))
-             continue
-         yield from zip(p_views, grad)
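
The module-level `split_p_and_g_in_group` removed here now lives on `StatefulOptimizer` (see the class hunk earlier in this diff) with an extra `beta1` argument that triggers the MARS correction when `group['mars']` is set. A hypothetical sketch of how a subclass might consume it; the `MyOpt` class, its `_step` hook, and the hyperparameters are illustrative, not package code:

    from heavyball.utils import StatefulOptimizer, update_param_

    class MyOpt(StatefulOptimizer):
        def _step(self, group):
            # Yields (param_view, grad_view) pairs; grads are promoted to fp32
            # and MARS-corrected only when beta1 >= 0 and group.get('mars') is truthy.
            for p, g in self.split_p_and_g_in_group(group, beta1=group.get('beta', 0.9)):
                if g is None:
                    continue
                update_param_([p], [g], -group['lr'], 0.0)  # plain SGD as a placeholder
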
heavyball-0.21.8.dist-info/METADATA → heavyball-0.23.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.21.8
+ Version: 0.23.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.
 
- Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
  ## Features
heavyball-0.23.0.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
+ heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
+ heavyball/cached_delayed_psgd_kron.py,sha256=n3wIOhrop0Ls4MZ0kXpwGuImp1jzPs6VGdxIlPyoYdQ,6827
+ heavyball/cached_psgd_kron.py,sha256=KCLsfvj9qh_2FNwRTdWM3zjnt2oGHfsf4Y341rPcceI,6778
+ heavyball/delayed_psgd.py,sha256=z_Y1eYr2upVt_FsyCIv91yTFJY6yqvHsI8S2mOpqdv8,6334
+ heavyball/foreach_adamw.py,sha256=uawSbGGUD2E1RtcwspP83yQNElERdGX-diqCI5e8FqE,2825
+ heavyball/foreach_adopt.py,sha256=DFEaPswVzdHcbxC-mirsf_okM_HR6r34PDUTty5CrUE,3547
+ heavyball/foreach_laprop.py,sha256=J4Vms0nAOMh3GQtAOPyrYOe5WtpzokVv25b9oDnwc2A,2833
+ heavyball/foreach_sfadamw.py,sha256=HWbLekY5BloHDIgrN2J0a7IolZCt8Ah2xkLAU_-5oSc,3079
+ heavyball/foreach_soap.py,sha256=7B_dP2Hm_xqwpBQiPYkv_c6eoRnU1dV2VZfvSoa4uJ8,4729
+ heavyball/p_adam.py,sha256=8BlZ6YoaDXawMiRbCxo0Kd5_0-pAn0MQIhL0LHNaRBs,6315
+ heavyball/palm_foreach_sfadamw.py,sha256=E8raxrBIkSmTEGFzwnfWxKwDJjBQE2vdsmyqfc8aL_A,3375
+ heavyball/palm_foreach_soap.py,sha256=IknGm_CzrqDIFEoCkejxjoZ4sfIy6RSoInqlMUOYLB4,6156
+ heavyball/precond_schedule_foreach_soap.py,sha256=bJ2ifPFa8zEP9GO8eBpqZzsmP7p_iQkkCkllNeEMHPU,4892
+ heavyball/precond_schedule_palm_foreach_soap.py,sha256=4dT9f134-Faq2KuCMCHzMtrkMO-es5p_DYS1of5yF-s,6428
+ heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTlb-n5gmgs,7530
+ heavyball/psgd_kron.py,sha256=4eiGPXAFjvGIXLdiai1UJfAvTozAV1TXaE9UGkE4BLc,6051
+ heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
+ heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
+ heavyball/utils.py,sha256=AZlY8dfM0d-C0FXBCJHTJOOoi3RjkMJ-XhU25aBN878,39521
+ heavyball-0.23.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-0.23.0.dist-info/METADATA,sha256=3IBUhXA7VJT9GQh460OznCAcIqCG_Mv5Q7HZO8FQ40w,11926
+ heavyball-0.23.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-0.23.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-0.23.0.dist-info/RECORD,,
heavyball-0.21.8.dist-info/RECORD DELETED
@@ -1,24 +0,0 @@
- heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
- heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
- heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
- heavyball/delayed_psgd.py,sha256=m4c-OvcLMrRxSAPYs2l6Up21uCyF2kvHvpcnfe3nzGs,6212
- heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
- heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
- heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
- heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
- heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
- heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
- heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
- heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
- heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
- heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
- heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
- heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
- heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
- heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
- heavyball/utils.py,sha256=xTDZEt2_DM57EYnJkRq7d7scTnro4eKPdMtEwPdLy-c,37218
- heavyball-0.21.8.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-0.21.8.dist-info/METADATA,sha256=nLyxHlENmhAGyU9GManYKKJJTykhsAMt7hkJNXPu_YY,11926
- heavyball-0.21.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-0.21.8.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-0.21.8.dist-info/RECORD,,