heavyball 0.21.8__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +6 -5
- heavyball/cached_delayed_psgd_kron.py +6 -5
- heavyball/cached_psgd_kron.py +7 -5
- heavyball/delayed_psgd.py +14 -11
- heavyball/foreach_adamw.py +14 -7
- heavyball/foreach_adopt.py +11 -6
- heavyball/foreach_laprop.py +12 -6
- heavyball/foreach_sfadamw.py +10 -3
- heavyball/foreach_soap.py +10 -8
- heavyball/p_adam.py +11 -9
- heavyball/palm_foreach_sfadamw.py +11 -3
- heavyball/palm_foreach_soap.py +8 -9
- heavyball/precond_schedule_foreach_soap.py +10 -8
- heavyball/precond_schedule_palm_foreach_soap.py +9 -9
- heavyball/precond_schedule_sfpsoap.py +10 -5
- heavyball/psgd_kron.py +9 -6
- heavyball/pure_psgd.py +11 -7
- heavyball/schedule_free_palm_foreach_soap.py +13 -5
- heavyball/utils.py +171 -106
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/METADATA +2 -2
- heavyball-0.23.0.dist-info/RECORD +24 -0
- heavyball-0.21.8.dist-info/RECORD +0 -24
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/LICENSE +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/WHEEL +0 -0
- {heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple, Callable, Union
 
 import numpy as np
 import torch
+from torch import Tensor
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
 
@@ -39,15 +40,14 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
-    p32 = promote(p)
-    …
-    copy_stochastic_(z, z32)
+def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor, beta1: Tensor):
+    p32, z32, g32 = [promote(x) for x in (p, z, grad)]
+    for p_, z_, g_ in zip(p32, z32, g32):
+        p_.lerp_(z_, ckp1)
+        p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
+        z_.add(g_, alpha=-lr)
+    copy_stochastic_list_(p, p32)
+    copy_stochastic_list_(z, z32)
 
 
 def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
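The rewritten kernel batches the schedule-free update over lists of parameters, but the per-parameter math is unchanged: the evaluation point is pulled toward the base iterate `z` by `ckp1`, then takes a damped gradient step, while `z` takes a plain SGD step. A minimal eager sketch of that math on a single fp32 parameter (values are made up; this is not the compiled kernel above):

    import torch

    p = torch.randn(4)        # evaluation point (y)
    z = p.clone()             # base iterate
    grad = torch.randn(4)
    lr, beta1, ckp1 = 1e-3, 0.9, 0.01

    p.lerp_(z, ckp1)                                    # interpolate toward z
    p.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))   # damped gradient step on p
    z.add_(grad, alpha=-lr)                             # plain SGD step on z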
@@ -61,8 +61,8 @@ def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
     return ckp1, weight_sum
 
 
-def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[…
-                   z: List[…
+def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
+                   z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
@@ -73,10 +73,8 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-    …
-    for p, z_, g in zip(parameters, z, grad):
-        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
+    lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
+    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
     return weight_sum
 
 
@@ -142,19 +140,25 @@ def beta_debias(beta, step):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def …
-    if isinstance(state, torch.Tensor):
-        state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-        return torch.sqrt(state, out=out).clamp_(min=eps)
-    …
+def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]):
     torch._foreach_mul_(state, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
     denom = torch._foreach_sqrt(state)
-    …
+    [denom.clamp_(min=eps) for denom in denom]
+    if out[0] is None:
+        return denom
 
+    copy_stochastic_list_(out, denom)
+    return out
 
-def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[…
+
+def exp_avg_sq_(state, grad, beta2, eps, out=None):
+    state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
+    beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
+    return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
+
+
+def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return
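The second-moment helper is now split into a compiled list kernel plus an `exp_avg_sq_` wrapper that normalizes its arguments. The math it computes per tensor is the usual Adam denominator; an eager, single-tensor sketch (illustrative values, not the compiled kernel):

    import torch

    state = torch.zeros(4)                        # running EMA of grad**2
    grad = torch.tensor([1.0, -2.0, 3.0, -4.0])
    beta2, eps = 0.999, 1e-8

    state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # state = beta2*state + (1-beta2)*grad^2
    denom = state.sqrt().clamp(min=eps)                      # square root with a floor of eps
    print(denom)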
@@ -168,12 +172,19 @@ def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[
     torch._foreach_mul_(gradients, p_norm)
 
 
-def …
-    …
+def is_compiling():
+    try:
+        return torch.compiler.is_compiling()
+    except AttributeError:
+        return True
+
+
+def set_(dst: Tensor, src: Tensor):
+    if not is_compiling() and src.data_ptr() == dst.data_ptr():
         return
     if src.shape != dst.shape:
         src = src.reshape_as(dst)
-    if not …
+    if not is_compiling() and src.is_contiguous() and dst.is_contiguous() and src.dtype == dst.dtype:
         dst.set_(src)
     else:
         dst.copy_(src)
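`is_compiling` exists because the `data_ptr()` fast paths in `set_` and `copy_stochastic_` cannot be evaluated while a graph is being traced; on PyTorch builds that lack `torch.compiler.is_compiling` the shim conservatively reports `True`, so those fast paths are simply skipped. A small sketch of the guard in isolation (hypothetical tensors, not heavyball code):

    import torch

    def is_compiling():
        try:
            return torch.compiler.is_compiling()   # available on recent PyTorch
        except AttributeError:
            return True                            # old build: assume tracing, skip fast paths

    dst, src = torch.zeros(3), torch.ones(3)
    if not is_compiling() and src.data_ptr() == dst.data_ptr():
        pass                                       # same storage: nothing to copy
    else:
        dst.copy_(src)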
@@ -329,7 +340,7 @@ def get_orthogonal_matrix(mat):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_lerp_(x: List[…
+def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
@@ -337,14 +348,28 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a
         copy_stochastic_(x_, x32)
 
 
-def stochastic_lerp_(x: List[…
-    …
+def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
+    x, y = list_guard(x), list_guard(y)
+    a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
+def list_guard(x):
+    if isinstance(x, (list, tuple)):
+        return x
+    return [x]
+
+
+def scalar_guard(x, ref):
+    if isinstance(x, float):
+        return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
+    if isinstance(x, int):
+        return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
+    return x
+
+
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_add_(x: List[…
+def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
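`list_guard` and `scalar_guard` are the normalization helpers this release threads through every public wrapper: callers may pass bare tensors and Python scalars, and the guards turn them into lists and 0-d tensors on the right device before the `torch.compile`d kernel runs, so the compiled graph only ever sees tensor arguments. A usage sketch with the definitions copied from the diff (example values are made up):

    import torch

    def list_guard(x):
        if isinstance(x, (list, tuple)):
            return x
        return [x]

    def scalar_guard(x, ref):
        if isinstance(x, float):
            return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
        if isinstance(x, int):
            return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
        return x

    p = torch.randn(8)
    params = list_guard(p)          # [p]: kernels always iterate over lists
    lr = scalar_guard(1e-3, p)      # 0-d float32 tensor on p's device
    step = scalar_guard(10, p)      # 0-d int64 tensor
    print(len(params), lr.item(), step.item())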
@@ -352,9 +377,9 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], al
         copy_stochastic_(x_, x32)
 
 
-def stochastic_add_(x: List[…
-    …
+def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
+    x, y = list_guard(x), list_guard(y)
+    alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
 
@@ -376,12 +401,12 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 def promote(x):
     if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if isinstance(x, …
+    if isinstance(x, Tensor) and x.dtype in (torch.bfloat16, torch.float16):
        return x.float()
     return x
 
 
-def min_dtype(xs: List[…
+def min_dtype(xs: List[Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
         if all(x in (d, torch.float32, torch.float64) for x in dtypes):
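`promote` accepts either a dtype or a tensor and upcasts half precision to fp32 so the stochastic-rounding helpers can accumulate in full precision, while `min_dtype` chooses a common compute dtype for a list of tensors, dropping to bf16/fp16 only when every tensor already allows it. A quick check of `promote` (assumes heavyball 0.23.0 is installed):

    import torch
    from heavyball.utils import promote

    x = torch.randn(4, dtype=torch.bfloat16)
    print(promote(x).dtype)               # torch.float32: bf16 tensors are upcast
    print(promote(torch.float16))         # torch.float32: half-precision dtypes are promoted too
    print(promote(torch.randn(4)).dtype)  # torch.float32: fp32 passes through unchanged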
@@ -447,7 +472,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         self.fake_groups = {}
         self.use_ema = use_ema
 
-    def key(self, param: …
+    def key(self, param: Tensor):
         return (param.data_ptr(), tuple(param.shape))
 
     def get_groups(self, group):
@@ -460,19 +485,56 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         return [self.fake_groups[self.key(p)] for p in group['params']]
 
-    def state_(self, arg: …
+    def state_(self, arg: Tensor):
         return self.state[self.key(arg)]
 
+    def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
+        for p, g in zip(p_list, g_list):
+            state = self.state_(p)
+            if 'mars_old_grad' not in state:
+                state['mars_old_grad'] = torch.zeros_like(g)
+        old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
+        mars_correction(g_list, old_gs, mars_gamma, beta)
+
+    def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
+                               beta1: float = -1.0):
+        for p in group["params"]:
+            if skip_none and p.grad is None:
+                continue
+
+            if p.grad is None:
+                grad = None
+            else:
+                if should_promote:
+                    grad = promote(p.grad)
+                else:
+                    grad = p.grad
+                if beta1 >= 0 and group.get('mars', False):
+                    self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+
+            p.grad = None
+
+            p_views = merge_group(group, p)
+            if grad is not None:
+                grad = merge_group(group, grad)
+            if isinstance(p_views, Tensor):
+                yield p_views, grad
+                continue
+            if grad is None:
+                yield from zip(p_views, [None] * len(p_views))
+                continue
+            yield from zip(p_views, grad)
+
     def state_size(self) -> int:
         total_bytes = 0
 
         def _add(x):
             nonlocal total_bytes
-            if isinstance(x, …
+            if isinstance(x, Tensor):
                 total_bytes += x.numel() * x.element_size()
 
         for group in self.param_groups:
-            for p, _ in split_p_and_g_in_group(group, skip_none=False):
+            for p, _ in self.split_p_and_g_in_group(group, skip_none=False):
                 tree_map(_add, self.state_(p))
         return total_bytes
 
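`split_p_and_g_in_group` moves from a module-level generator (deleted at the bottom of this diff) onto `StatefulOptimizer`, and gains an optional MARS correction: when a group sets `mars`, each gradient is blended with the previous step's gradient via `mars_correct_list` (using `group['mars_gamma']` and the passed `beta1`) before being merged and yielded. A hedged sketch of how a subclass might consume it; the `_step(self, group)` hook name and the hyperparameters here are assumptions, not taken from this diff:

    import torch
    from heavyball.utils import StatefulOptimizer  # assumes heavyball 0.23.0 is installed

    class GradInspector(StatefulOptimizer):
        # Illustrative only: record the norm of each merged gradient view.
        def _step(self, group):
            for p, g in self.split_p_and_g_in_group(group, beta1=0.9):
                if g is None:
                    continue
                # g is already promoted to fp32 and, if group['mars'] is set,
                # already MARS-corrected with group['mars_gamma'].
                self.state_(p)['grad_norm'] = g.norm()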
@@ -576,13 +638,14 @@ class ScheduleFree(StatefulOptimizer):
         raise NotImplementedError
 
 
-def copy_stochastic_list_(target: List[…
+def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
     for t, s in zip(target, source):
         copy_stochastic_(t, s)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2…
+def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                         grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -595,21 +658,17 @@ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2
     return denom
 
 
-def exp_avg_(exp_avg: List[…
-    …
-    beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
-    if isinstance(step, int):
-        step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
+             beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
+        grad), list_guard(grad_projected)
+    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
     denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
     return denom
 
 
-…
-def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
@@ -624,8 +683,8 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     target.copy_(result.view(dtype=torch.float32))
 
 
-def copy_stochastic_(target: …
-    if not …
+def copy_stochastic_(target: Tensor, source: Tensor):
+    if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
     if target.dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
         set_(target, source)
@@ -633,26 +692,31 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_update_(p, u, decay, add_fn, lr…
+def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+                        g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
     p32, u32 = [list(map(promote, x)) for x in [p, u]]
 
     if decay > 0:
         torch._foreach_mul_(p32, 1 - decay * lr)
 
-    for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
-        if …
-            …
-        add_fn(p32_, u32_, lr)
+    for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
+        if caution:
+            _compilable_cautioning_(promote(g_), u32_)
+        add_fn(p32_, u32_, lr)
 
     copy_stochastic_list_(p, p32)
 
 
-def update_param_(param: List[…
-    …
+def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
+                  caution: bool = False, grad: List[Tensor] = None):
+    param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    lr = scalar_guard(lr, param[0])
+    if not caution:
+        grad = [None] * len(param)
+    if add_fn is None:
+        add_fn = stochastic_add_
+    _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
 
 
 def precond_schedule(step, precond_scheduler, rng):
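`update_param_` now accepts `caution` and `grad`, wiring the cautious-update mask (defined further down in this diff) into the compiled parameter update, and `add_fn` defaults to `stochastic_add_` when not given. A call sketch showing only the new keyword arguments; the tensors are made up and the sign convention of `update` is whatever the caller already uses:

    import torch
    from heavyball.utils import update_param_  # assumes heavyball 0.23.0 is installed

    params = [torch.randn(16)]
    updates = [torch.randn(16)]   # precomputed update direction
    grads = [torch.randn(16)]     # raw gradients, only consulted when caution=True

    # Entries of each update that disagree in sign with the matching gradient are
    # zeroed (and the survivors rescaled) before the parameter is touched.
    update_param_(params, updates, lr=1e-3, decay=0.0, caution=True, grad=grads)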
@@ -822,14 +886,14 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def psgd_precond_grad(…
+def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
     """Precondition gradient G with preconditioner Q."""
-    md = min_dtype(…
-    out = torch.einsum(exprs…
+    md = min_dtype(preconds)
+    out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
     if inplace:
-        set_(…
-        return …
-    return out.to(…
+        set_(grad, out)
+        return grad
+    return out.to(grad.dtype)
 
 
 def norm_clip_(x, scale=None):
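`psgd_precond_grad` now takes the preconditioner factors as `*preconds` and applies each factor twice (once conjugated), i.e. it preconditions with Qᴴ Q along every mode of the gradient. For a 2-D gradient with two square Kronecker factors this is roughly the following eager computation; the real einsum string is generated per parameter, so the index order here is illustrative, not heavyball's:

    import torch

    def precond_2d(Ql: torch.Tensor, Qr: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
        Pl = Ql.conj().T @ Ql    # left preconditioner, P = Q^H Q
        Pr = Qr.conj().T @ Qr    # right preconditioner
        return Pl @ G @ Pr       # vec form: (Pr ⊗ Pl) vec(G)

    G = torch.randn(8, 4)
    Ql, Qr = torch.randn(8, 8), torch.randn(4, 4)
    print(precond_2d(Ql, Qr, G).shape)  # torch.Size([8, 4])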
@@ -892,7 +956,7 @@ def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 
 @decorator
-def triu_to_line(Q_list: List[…
+def triu_to_line(Q_list: List[Tensor]):
     out = []
     for q in Q_list:
         if q.dim() < 2:
@@ -909,7 +973,7 @@ def _triu_shape(numel):
 
 
 @decorator
-def line_to_triu(Q_list: List[Tuple[Optional[List[int]], …
+def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
     new = []
     for shape, q in Q_list:
         if shape is not None:
@@ -965,18 +1029,45 @@ class PSGDBase(StatefulOptimizer):
             psgd_balance_Q(q)
 
 
+# TODO: Figure out why this sometimes crashes
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_precond_grad_cached_(…
+def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
+                                     clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
-    update_param_([param], clip_fn([new]), lr, weight_decay)
+    update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
+
+
+def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
+                         clip_fn, caution, grad):
+    lr = scalar_guard(lr, param)
+    _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_mars_correction_(g: Tensor, old_g: Tensor, a: Tensor):
+    g_copy = [g_.clone() for g_ in g]
+    _compilable_stochastic_lerp_(g, old_g, a)
+    copy_stochastic_list_(old_g, g_copy)
+
+
+def mars_correction(g, old_g, beta1, gamma):
+    a = -gamma * beta1 / (1 - beta1)
+    g, old_g = list_guard(g), list_guard(old_g)
+    a = scalar_guard(a, g[0])
+    _compilable_mars_correction_(g, old_g, a)
 
 
-…
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_cautioning_(g: Tensor, update: Tensor):
+    mask = (g * update) > 0
+    update.masked_fill_(~mask, 0)
+    scale = mask.numel() / mask.sum().clamp(min=1)
+    update.mul_(scale)
+
+
+def caution(g, update):
+    _compilable_cautioning_(g, update)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
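`_compilable_cautioning_` implements the "cautious optimizer" mask that the new `caution` flags feed, and the MARS helpers added alongside it blend each gradient with the previous one before the usual update. The cautioning logic itself is simple: update entries whose sign disagrees with the gradient are zeroed, and the survivors are rescaled by the inverse of the kept fraction. The same math in plain eager PyTorch (hypothetical values):

    import torch

    def cautious_mask(g: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        mask = (g * update) > 0                      # keep entries that agree with the gradient
        masked = update.masked_fill(~mask, 0)
        return masked * (mask.numel() / mask.sum().clamp(min=1))  # rescale the survivors

    g = torch.tensor([0.5, -1.0, 2.0, -0.1])
    u = torch.tensor([1.0, 1.0, -1.0, -1.0])
    print(cautious_mask(g, u))  # tensor([2., 0., 0., -2.]): 2 of 4 entries survive, scaled by 4/2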
@@ -1013,29 +1104,3 @@ def merge_group(group, *tensors):
         append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
             'max_precond_dim'], group.get('split', False)))
     return out
-
-
-def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
-    for p in group["params"]:
-        if skip_none and p.grad is None:
-            continue
-
-        if p.grad is None:
-            grad = None
-        else:
-            if should_promote:
-                grad = promote(p.grad)
-            else:
-                grad = p.grad
-        p.grad = None
-
-        p_views = merge_group(group, p)
-        if grad is not None:
-            grad = merge_group(group, grad)
-        if isinstance(p_views, torch.Tensor):
-            yield p_views, grad
-            continue
-        if grad is None:
-            yield from zip(p_views, [None] * len(p_views))
-            continue
-        yield from zip(p_views, grad)
{heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.21.8
+Version: 0.23.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-…
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
heavyball-0.23.0.dist-info/RECORD
ADDED
@@ -0,0 +1,24 @@
+heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
+heavyball/cached_delayed_psgd_kron.py,sha256=n3wIOhrop0Ls4MZ0kXpwGuImp1jzPs6VGdxIlPyoYdQ,6827
+heavyball/cached_psgd_kron.py,sha256=KCLsfvj9qh_2FNwRTdWM3zjnt2oGHfsf4Y341rPcceI,6778
+heavyball/delayed_psgd.py,sha256=z_Y1eYr2upVt_FsyCIv91yTFJY6yqvHsI8S2mOpqdv8,6334
+heavyball/foreach_adamw.py,sha256=uawSbGGUD2E1RtcwspP83yQNElERdGX-diqCI5e8FqE,2825
+heavyball/foreach_adopt.py,sha256=DFEaPswVzdHcbxC-mirsf_okM_HR6r34PDUTty5CrUE,3547
+heavyball/foreach_laprop.py,sha256=J4Vms0nAOMh3GQtAOPyrYOe5WtpzokVv25b9oDnwc2A,2833
+heavyball/foreach_sfadamw.py,sha256=HWbLekY5BloHDIgrN2J0a7IolZCt8Ah2xkLAU_-5oSc,3079
+heavyball/foreach_soap.py,sha256=7B_dP2Hm_xqwpBQiPYkv_c6eoRnU1dV2VZfvSoa4uJ8,4729
+heavyball/p_adam.py,sha256=8BlZ6YoaDXawMiRbCxo0Kd5_0-pAn0MQIhL0LHNaRBs,6315
+heavyball/palm_foreach_sfadamw.py,sha256=E8raxrBIkSmTEGFzwnfWxKwDJjBQE2vdsmyqfc8aL_A,3375
+heavyball/palm_foreach_soap.py,sha256=IknGm_CzrqDIFEoCkejxjoZ4sfIy6RSoInqlMUOYLB4,6156
+heavyball/precond_schedule_foreach_soap.py,sha256=bJ2ifPFa8zEP9GO8eBpqZzsmP7p_iQkkCkllNeEMHPU,4892
+heavyball/precond_schedule_palm_foreach_soap.py,sha256=4dT9f134-Faq2KuCMCHzMtrkMO-es5p_DYS1of5yF-s,6428
+heavyball/precond_schedule_sfpsoap.py,sha256=FOR-axwlkSN7IHZWYYUVFfjSFCLxc_NdiTlb-n5gmgs,7530
+heavyball/psgd_kron.py,sha256=4eiGPXAFjvGIXLdiai1UJfAvTozAV1TXaE9UGkE4BLc,6051
+heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
+heavyball/schedule_free_palm_foreach_soap.py,sha256=0WT_gvTKymqLQzYT6ewDgCmpDq-HgMAewipw1QvyQYA,7267
+heavyball/utils.py,sha256=AZlY8dfM0d-C0FXBCJHTJOOoi3RjkMJ-XhU25aBN878,39521
+heavyball-0.23.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.23.0.dist-info/METADATA,sha256=3IBUhXA7VJT9GQh460OznCAcIqCG_Mv5Q7HZO8FQ40w,11926
+heavyball-0.23.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.23.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.23.0.dist-info/RECORD,,
heavyball-0.21.8.dist-info/RECORD
DELETED
@@ -1,24 +0,0 @@
-heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
-heavyball/cached_delayed_psgd_kron.py,sha256=Nyxl-G-o6greKwDN-vLiw5W02GXO2LRvknc0OzvzFnE,6674
-heavyball/cached_psgd_kron.py,sha256=HzD6se0AYb-W5hpydUxcR9uqrpe_54PBwgL1VWX3DHU,6592
-heavyball/delayed_psgd.py,sha256=m4c-OvcLMrRxSAPYs2l6Up21uCyF2kvHvpcnfe3nzGs,6212
-heavyball/foreach_adamw.py,sha256=Rb5U80cgUcEqlEbUU250UTWdoqA7nyiqkV5w1U4bWX4,2445
-heavyball/foreach_adopt.py,sha256=ecdi1fKg9i087OGjtKWVbE_DD6Yf4pvpzv4ELCcusvQ,3211
-heavyball/foreach_laprop.py,sha256=vi6C_gfjXxw5uN0KHgzxI9itUI1dcgOf3ufoO_VVMp0,2471
-heavyball/foreach_sfadamw.py,sha256=rLZORmCIMu9G09FdDgMSiI6pNq34IVoxsPVWtmeDdbQ,2753
-heavyball/foreach_soap.py,sha256=4mWSMWYTdjgiXiboI5DwdigecruDtNGKylGAFAVhCRA,4562
-heavyball/p_adam.py,sha256=Xyxsavwtw-t0OyTHitYQXZSmF9UJlMDzDAURge-MbbQ,6047
-heavyball/palm_foreach_sfadamw.py,sha256=JbNrcoquBGGUI5XNMFouDjpNurVHUW9DbX1A3tSrtno,3025
-heavyball/palm_foreach_soap.py,sha256=GzAwM8kOt1X0QCmUZDTdHwPxbJwjH8ic43dyAK5BYCA,6015
-heavyball/precond_schedule_foreach_soap.py,sha256=HcObXLfSNN_lKNb4nmC6tkdHcqDIMNX6hILpHKScqLc,4744
-heavyball/precond_schedule_palm_foreach_soap.py,sha256=xZ7CJvIfdu2RNAZt2g1S7Xb0Jyy1hNC4MowOFU3nWkk,6283
-heavyball/precond_schedule_sfpsoap.py,sha256=PNneiOkrRyV1yIZn91lPmYofd1_OiLqJTDy75RLpXJk,7273
-heavyball/psgd_kron.py,sha256=RSLJi5FnSmYjvYBufNDClnYRm-eN_Kpa1Ar2tNP6-X0,5824
-heavyball/pure_psgd.py,sha256=LZK0qmvZkBF8g00evaVLtW-sIUJmdoxag1K7O26AqEo,4820
-heavyball/schedule_free_palm_foreach_soap.py,sha256=_AqrnChY6iDlQUkF2YUxS7eLjSWCIuvEUOHvMHVM1yY,6873
-heavyball/utils.py,sha256=xTDZEt2_DM57EYnJkRq7d7scTnro4eKPdMtEwPdLy-c,37218
-heavyball-0.21.8.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-0.21.8.dist-info/METADATA,sha256=nLyxHlENmhAGyU9GManYKKJJTykhsAMt7hkJNXPu_YY,11926
-heavyball-0.21.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-0.21.8.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-0.21.8.dist-info/RECORD,,
{heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/LICENSE
File without changes
{heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/WHEEL
File without changes
{heavyball-0.21.8.dist-info → heavyball-0.23.0.dist-info}/top_level.txt
File without changes