heavyball 0.18.4__py3-none-any.whl → 0.18.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/foreach_adamw.py +2 -2
- heavyball/utils.py +35 -18
- {heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/METADATA +1 -1
- {heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/RECORD +7 -7
- {heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/LICENSE +0 -0
- {heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/WHEEL +0 -0
- {heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/top_level.txt +0 -0
heavyball/foreach_adamw.py
CHANGED
@@ -34,9 +34,9 @@ class ForeachAdamW(StatefulOptimizer):
 
         # Decay the first and second moment running average coefficient
         torch._foreach_lerp_(exp_avg, grad, 1 - beta_debias(group['betas'][0], k + 1))
-        denom = exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps)
+        denom = list(exp_avg_sq_(exp_avg_sq, grad, beta_debias(group['betas'][1], k + 1), eps))
 
         # Normalize grad in-place for memory efficiency
         lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-        update_param_(y, exp_avg, lr, decay, lambda p, e, l:
+        update_param_(y, exp_avg, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l))
         group['k'] = k + 1
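The first change materializes `denom` into a list: the refactored `update_param_` (see the utils.py hunks below) invokes the callback once per parameter, and `denom.pop(0)` hands each call its own denominator tensor, which only works on a list. A minimal self-contained illustration of that pattern (tensors and values are made up):

import torch

params = [torch.zeros(3), torch.zeros(5)]
exp_avg = [torch.ones(3), torch.ones(5)]
denom = [torch.full((3,), 2.0), torch.full((5,), 2.0)]  # one entry per parameter

# Each call consumes the next denominator; a generator or tuple would fail
# here because only lists support pop().
add_fn = lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l)
for p, e in zip(params, exp_avg):
    add_fn(p, e, -0.1)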
heavyball/utils.py
CHANGED
@@ -38,6 +38,18 @@ def warmup(lr: float, step: int, warmup_steps: int):
     return lr * step / warmup_steps
 
 
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
+    p32 = p.float()
+    z32 = z.float()
+    p32.lerp_(end=z32, weight=1 - ckp1)
+    p32.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
+    _compilable_copy_stochastic_(p, p32)
+
+    z32.add_(grad, alpha=-lr)
+    _compilable_copy_stochastic_(z, z32)
+
+
 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
                    z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
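The new compiled helper performs the schedule-free step on fp32 working copies: p32 is moved toward z32 by a fraction 1 − ckp1, nudged along the gradient, and written back stochastically, while z32 takes a plain gradient step. Note that `lr` and `ckp1` arrive as 0-d tensors (see the next hunk); presumably this keeps `dynamic=True` from specializing the compiled graph on their values, so changing the learning rate each step reuses one kernel. A minimal sketch of that scalar-as-tensor pattern (hypothetical function, not from heavyball):

import torch

@torch.compile(fullgraph=True, dynamic=True)
def decay_step(p, lr):
    # lr is a 0-d tensor: updating it in place reuses the compiled graph,
    # whereas a Python float would be baked in as a constant.
    p.mul_(1 - lr)

p = torch.ones(4)
lr = torch.empty((), dtype=torch.float32).fill_(1e-3)
for _ in range(3):
    decay_step(p, lr)
    lr.mul_(0.5)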
@@ -50,15 +62,10 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-
-
-
-
-    copy_stochastic_list_(parameters, p32)
-
-    # z step
-    torch._foreach_sub_(z, grad, alpha=lr)
-    copy_stochastic_list_(z, z32)
+    lr_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(lr)
+    ckp1_tensor = torch.empty((), dtype=torch.float32, device=parameters[0].device).fill_(ckp1)
+    for p, z_, g in zip(parameters, z, grad):
+        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
     return weight_sum
 
 
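So instead of the fused `torch._foreach_*` path, `schedule_free_` now loops over parameters and calls the compiled kernel once per tensor, trading Python loop overhead for per-tensor kernel fusion. For orientation, a hypothetical driver for the refactored function (shapes, hyperparameters, and stand-in gradients are illustrative, not a tested recipe):

import torch
from heavyball.utils import schedule_free_

params = [torch.randn(16), torch.randn(8)]
z = [p.clone() for p in params]  # schedule-free z sequence, one per parameter
weight_sum = 0.0

for step in range(1, 11):
    grads = [torch.randn_like(p) for p in params]  # stand-in gradients
    weight_sum = schedule_free_(1e-3, 2.0, weight_sum, 0.9,
                                params, z, grads, step=step)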
@@ -504,17 +511,24 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     _compilable_copy_stochastic_(target, source)
 
 
-
-
-
-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_update_one_(p, u, decay, add_fn, lr):
+    p32 = p.float()
+    u32 = u.view(p.shape).float()
     if decay > 0:
-
+        p32.mul_(1 - decay * lr)
     if add_fn is None:
-
+        p32.add_(u32, alpha=lr)
     else:
-        add_fn(
-
+        add_fn(p32, u32, lr)
+    _compilable_copy_stochastic_(p, p32)
+
+
+def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
+                  add_fn: callable = None):
+    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+    for p, u in zip(param, update):
+        _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
 
 
 def precond_schedule(step, precond_scheduler, rng):
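`update_param_` follows the same refactor: the old foreach body becomes a per-parameter compiled kernel that upcasts to fp32, applies weight decay and the update, and writes back via `_compilable_copy_stochastic_`. That write-back presumably uses the widely known fp32 → bfloat16 stochastic-rounding trick; a self-contained sketch of the general technique (not necessarily heavyball's exact implementation):

import torch

def copy_stochastic(target: torch.Tensor, source: torch.Tensor):
    """Copy fp32 `source` into bf16 `target` with stochastic rounding."""
    # bf16 keeps the top 16 bits of an fp32; adding 16 random low bits before
    # truncation rounds up with probability proportional to the dropped bits.
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=1 << 16)
    result.add_(source.view(dtype=torch.int32))
    result.bitwise_and_(-65536)  # zero the low 16 bits
    target.copy_(result.view(dtype=torch.float32))

x32 = torch.randn(1024)
x16 = torch.empty(1024, dtype=torch.bfloat16)
copy_stochastic(x16, x32)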
@@ -828,7 +842,10 @@ class PSGDBase(StatefulOptimizer):
 
         for g, q in zip(grad_list, original_q if original_q else q_list):
             if g.dim() > 1:
-
+                if store_triu_as_line:
+                    psgd_balance_Q([q_ for _, q_ in q])
+                else:
+                    psgd_balance_Q(q)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250):
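This guard fixes balancing when `store_triu_as_line` is enabled: in that mode each Q factor appears to be stored as a pair whose second element holds the values, so they must be unpacked before `psgd_balance_Q` sees them. For reference, balancing Kronecker factors conventionally means equalizing their magnitudes without changing their product; a sketch of such a helper under that assumption (hypothetical, not heavyball's exact code):

import torch

def balance_q_sketch(q_list):
    # Rescale every factor toward the geometric mean of their max magnitudes;
    # the product of all factors is unchanged because the scales cancel.
    norms = torch.stack([q.abs().max() for q in q_list])
    gmean = norms.log().mean().exp()
    for q, n in zip(q_list, norms):
        q.mul_(gmean / n)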
{heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ heavyball/__init__.py,sha256=iqP428JWwwx-XDOZ0nUdbCkOLEyfoqVyWZLQLAcwxaw,2214
 heavyball/cached_delayed_psgd_kron.py,sha256=PQAER6UgVh5l87DGRZrJ8CVP9UhyCG5wJD9rPLnj_G8,6460
 heavyball/cached_psgd_kron.py,sha256=GaeneBp0irksCSBIrJY4D_0hCpZ-uSRPMhqVX_a-og8,6417
 heavyball/delayed_psgd.py,sha256=fhBWFLTSl1S2gHWCeYak-STaXRwpC56sWZGLFMKFEJM,5589
-heavyball/foreach_adamw.py,sha256=
+heavyball/foreach_adamw.py,sha256=CTg7rfUmlTSjihD5KY9xP0sT2dUKZyZ4-2V42Vlr28U,1780
 heavyball/foreach_adopt.py,sha256=ogOw2JjwEQNj7AKlweAphQFdMJ_GcMDm-RyDvEzugoc,1911
 heavyball/foreach_laprop.py,sha256=yGVmGqWiSw8Y2Xj70ndkR8ZMygakTB4_iRwV02Svkqg,1816
 heavyball/foreach_sfadamw.py,sha256=15-n6-lx4PAHYsKYmXbugxsR5MnqaPYy2vUudPRiitg,2087
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=vq7jd302refKPa_9X2lkOTOtCCcTBVByPdo
 heavyball/psgd_kron.py,sha256=u46dorOUXx-do1IYeno2wj-6l1zYKMQQC-N2Zr2PzLI,5476
 heavyball/pure_psgd.py,sha256=iUy7mMKWxwNiVUMYrQ7SBnreu3t_XSbnhTW3a1yw4m0,4835
 heavyball/schedule_free_palm_foreach_soap.py,sha256=zkcikH5wWbzq4kOrmBjilvY3iWzuUddcv2HNEPKr3MI,6366
-heavyball/utils.py,sha256=
-heavyball-0.18.
-heavyball-0.18.
-heavyball-0.18.
-heavyball-0.18.
-heavyball-0.18.
+heavyball/utils.py,sha256=2VBEQhtQ4mwsD99JMu7iWbiYPkutspjG3hGwCbIHZ9U,31134
+heavyball-0.18.5.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-0.18.5.dist-info/METADATA,sha256=Zcc87BhCxDTX7bjJ3pGG7VdIRmpZuYLwmWBKDiLc3AU,11810
+heavyball-0.18.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-0.18.5.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-0.18.5.dist-info/RECORD,,
{heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/LICENSE
File without changes
{heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/WHEEL
File without changes
{heavyball-0.18.4.dist-info → heavyball-0.18.5.dist-info}/top_level.txt
File without changes