heavyball 0.22.0__tar.gz → 0.23.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.22.0 → heavyball-0.23.1}/PKG-INFO +2 -2
- {heavyball-0.22.0 → heavyball-0.23.1}/README.md +1 -1
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/delayed_psgd.py +6 -6
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/p_adam.py +2 -2
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/psgd_kron.py +1 -1
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/pure_psgd.py +1 -1
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/utils.py +87 -88
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball.egg-info/PKG-INFO +2 -2
- {heavyball-0.22.0 → heavyball-0.23.1}/setup.py +1 -1
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_bf16_params.py +0 -8
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_bf16_q.py +0 -8
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_bf16_storage.py +0 -6
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_caution.py +0 -9
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_mars.py +3 -11
- {heavyball-0.22.0 → heavyball-0.23.1}/LICENSE +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/__init__.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/cached_delayed_psgd_kron.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/cached_psgd_kron.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/setup.cfg +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_closure.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_ema.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_foreach.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_memory.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_merge.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_no_grad.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_psgd.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_soap.py +0 -0
- {heavyball-0.22.0 → heavyball-0.23.1}/test/test_stochastic_updates.py +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.22.0
+Version: 0.23.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
--- a/heavyball/delayed_psgd.py
+++ b/heavyball/delayed_psgd.py
@@ -5,16 +5,16 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
-from heavyball.utils import stochastic_lerp_, beta_debias
+from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    triu_to_line, line_to_triu, promote
+    triu_to_line, line_to_triu, promote, _compilable_update_
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr,
-    new = psgd_precond_grad(
-
+def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
+    new = psgd_precond_grad(False, exprs, ea, *q)
+    _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
 
 
 class ForeachDelayedPSGD(PSGDBase):
@@ -114,7 +114,7 @@ class ForeachDelayedPSGD(PSGDBase):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
             q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-            _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
+            _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"][-1], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
                                            g)
             if should_update:
                 q32 = [promote(q_) for q_ in q]
--- a/heavyball/p_adam.py
+++ b/heavyball/p_adam.py
@@ -110,8 +110,8 @@ class ForeachPaLMPAdam(PSGDBase):
 
         for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
             gc = g.clone() if group['caution'] else None
-            psgd_precond_grad(
-            ea = psgd_precond_grad(
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *Q)
+            ea = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *Q)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
             torch.div(ea, g, out=g)
             """
--- a/heavyball/psgd_kron.py
+++ b/heavyball/psgd_kron.py
@@ -116,5 +116,5 @@ class ForeachPSGDKron(PSGDBase):
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
                                store_triu_as_line)
-            g = psgd_precond_grad(
+            g = psgd_precond_grad(False, self.state_(p)["exprs"][-1], ea, *q)
             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
--- a/heavyball/pure_psgd.py
+++ b/heavyball/pure_psgd.py
@@ -101,5 +101,5 @@ class ForeachPurePSGD(PSGDBase):
             if group:
                 q32 = [promote(q_) for q_ in q]
                 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
-            psgd_precond_grad(
+            psgd_precond_grad(True, self.state_(p)["exprs"][-1], g, *q)
             update_param_([p], self.clip_fn([g]), lr, weight_decay, caution=group['caution'], grad=[g])
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple, Callable, Union
 
 import numpy as np
 import torch
+from torch import Tensor
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
 
@@ -39,15 +40,14 @@ def warmup(lr: float, step: int, warmup_steps: int):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_schedule_free_(p, z, ckp1, grad, lr, beta1):
-    p32 = promote(p)
-
-
-
-
-
-    copy_stochastic_(z, z32)
+def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor, beta1: Tensor):
+    p32, z32, g32 = [promote(x) for x in (p, z, grad)]
+    for p_, z_, g_ in zip(p32, z32, g32):
+        p_.lerp_(z_, ckp1)
+        p_.add_(g_, alpha=lr * (beta1 * (1 - ckp1) - 1))
+        z_.add(g_, alpha=-lr)
+    copy_stochastic_list_(p, p32)
+    copy_stochastic_list_(z, z32)
 
 
 def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
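The rewritten `_compilable_schedule_free_` batches the schedule-free step over whole parameter lists instead of compiling one kernel per tensor. A minimal plain-float32 sketch of the update rule it implements (a hypothetical standalone version, ignoring the stochastic bfloat16 write-back):

```python
import torch

def schedule_free_step(p: torch.Tensor, z: torch.Tensor, grad: torch.Tensor,
                       lr: float, beta1: float, ckp1: float):
    # Move the evaluation point p toward the slow iterate z.
    p.lerp_(z, ckp1)
    # Gradient step on p; the (beta1 * (1 - ckp1) - 1) factor folds the
    # interpolation correction into a single fused add.
    p.add_(grad, alpha=lr * (beta1 * (1 - ckp1) - 1))
    # Plain SGD step on the slow iterate z.
    z.add_(grad, alpha=-lr)
```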
@@ -61,8 +61,8 @@ def get_ckp1(lr, weight_lr_power, weight_sum, r, step):
     return ckp1, weight_sum
 
 
-def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[torch.Tensor],
-                   z: List[torch.Tensor], grad: list[torch.Tensor], r: float = 0.0, step: int = 0):
+def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
+                   z: List[Tensor], grad: list[Tensor], r: float = 0.0, step: int = 0):
     weight = lr ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
@@ -73,10 +73,8 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
 
     # These operations update y in-place,
     # without computing x explicitly.
-
-
-    for p, z_, g in zip(parameters, z, grad):
-        _compilable_schedule_free_(p, z_, ckp1_tensor, g, lr_tensor, beta1)
+    lr, ckp1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0])
+    _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1)
     return weight_sum
 
 
@@ -142,27 +140,25 @@ def beta_debias(beta, step):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_exp_avg_sq_(state, grad, beta2, eps, out
+def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]):
     torch._foreach_mul_(state, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(state, grad)]
     denom = torch._foreach_sqrt(state)
     [denom.clamp_(min=eps) for denom in denom]
-    if out is
-
-    return out
+    if out[0] is None:
+        return denom
 
-
+    copy_stochastic_list_(out, denom)
+    return out
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad = list_guard(state), list_guard(grad)
-
-    beta2 = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(beta2)
-    if not isinstance(eps, torch.Tensor):
-        eps = torch.empty((), dtype=torch.float32, device=state[0].device).fill_(eps)
+    state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
+    beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
-def adaptive_gradient_clipping_(parameters: List[torch.Tensor], gradients: List[torch.Tensor], clip_val: float,
+
+def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return
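`_compilable_exp_avg_sq_` maintains the Adam-style second moment, v ← β₂·v + (1−β₂)·g², and returns √v clamped at eps. A plain-tensor sketch of the same computation:

```python
import torch

v = torch.zeros(4)            # running second moment (optimizer state)
g = torch.randn(4)            # current gradient
beta2, eps = 0.999, 1e-8

v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # v <- beta2*v + (1-beta2)*g*g
denom = v.sqrt().clamp_(min=eps)               # Adam denominator
```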
@@ -183,7 +179,7 @@ def is_compiling():
         return True
 
 
-def set_(dst: torch.Tensor, src: torch.Tensor):
+def set_(dst: Tensor, src: Tensor):
     if not is_compiling() and src.data_ptr() == dst.data_ptr():
         return
     if src.shape != dst.shape:
@@ -344,7 +340,7 @@ def get_orthogonal_matrix(mat):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
@@ -352,10 +348,9 @@ def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
         copy_stochastic_(x_, x32)
 
 
-def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
     x, y = list_guard(x), list_guard(y)
-
-    a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+    a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
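`stochastic_lerp_` promotes low-precision tensors to float32, interpolates x ← x + a·(y − x), and writes the result back (stochastically rounded when x is bfloat16). A rough equivalent on plain tensors, with the stochastic write-back replaced by an ordinary cast:

```python
import torch

x = torch.randn(4, dtype=torch.bfloat16)
y = torch.randn(4, dtype=torch.bfloat16)

x32, y32 = x.float(), y.float()  # promote to float32 for the math
x32.lerp_(y32, 0.9)              # x32 <- x32 + 0.9 * (y32 - x32)
x.copy_(x32)                     # the package uses copy_stochastic_ here instead
```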
@@ -365,8 +360,16 @@ def list_guard(x):
         return [x]
 
 
+def scalar_guard(x, ref):
+    if isinstance(x, float):
+        return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
+    if isinstance(x, int):
+        return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
+    return x
+
+
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
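`scalar_guard` is the counterpart to `list_guard`: it turns a Python scalar into a 0-dim tensor on the reference tensor's device, so the `torch.compile`d kernels above see a stable tensor argument instead of recompiling (or burning in a constant) for every new Python float. Usage sketch:

```python
import torch

p = torch.zeros(4)
lr = scalar_guard(1e-3, p)    # 0-dim float32 tensor on p's device
step = scalar_guard(10, p)    # 0-dim int64 tensor
alpha = scalar_guard(lr, p)   # already a tensor -> returned unchanged
```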
@@ -374,10 +377,9 @@ def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
         copy_stochastic_(x_, x32)
 
 
-def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
     x, y = list_guard(x), list_guard(y)
-
-    alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+    alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
 
@@ -399,12 +401,12 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 def promote(x):
     if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
         return torch.float32
-    if isinstance(x, torch.Tensor) and x.dtype in (torch.bfloat16, torch.float16):
+    if isinstance(x, Tensor) and x.dtype in (torch.bfloat16, torch.float16):
         return x.float()
     return x
 
 
-def min_dtype(xs: List[torch.Tensor]):
+def min_dtype(xs: List[Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
         if all(x in (d, torch.float32, torch.float64) for x in dtypes):
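`min_dtype` picks the narrowest dtype the whole list can share: float32 wins only when every tensor is already float32/float64, while a single bfloat16 or float16 tensor pulls the shared dtype down to that type instead of silently upcasting everything. For example:

```python
import torch

xs = [torch.zeros(2, dtype=torch.bfloat16), torch.zeros(2, dtype=torch.float32)]
# dtypes = [bfloat16, float32]; float32 fails the all(...) check because of
# the bfloat16 entry, so the loop settles on bfloat16.
# min_dtype(xs) -> torch.bfloat16
```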
@@ -470,7 +472,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
         self.fake_groups = {}
         self.use_ema = use_ema
 
-    def key(self, param: torch.Tensor):
+    def key(self, param: Tensor):
         return (param.data_ptr(), tuple(param.shape))
 
     def get_groups(self, group):
@@ -483,7 +485,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         return [self.fake_groups[self.key(p)] for p in group['params']]
 
-    def state_(self, arg: torch.Tensor):
+    def state_(self, arg: Tensor):
         return self.state[self.key(arg)]
 
     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
@@ -515,7 +517,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
-            if isinstance(p_views, torch.Tensor):
+            if isinstance(p_views, Tensor):
                 yield p_views, grad
                 continue
             if grad is None:
@@ -528,7 +530,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
         def _add(x):
             nonlocal total_bytes
-            if isinstance(x, torch.Tensor):
+            if isinstance(x, Tensor):
                 total_bytes += x.numel() * x.element_size()
 
         for group in self.param_groups:
@@ -636,13 +638,14 @@ class ScheduleFree(StatefulOptimizer):
         raise NotImplementedError
 
 
-def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]):
+def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
     for t, s in zip(target, source):
         copy_stochastic_(t, s)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad
+def _compilable_exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
+                         grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
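`_compilable_exp_avg_` re-derives the effective EMA coefficients with `beta_debias` each step. Assuming the usual definition (not shown in this hunk), the debiased beta makes a freshly-initialized running average unbiased:

```python
def beta_debias(beta: float, step: int) -> float:
    # Equivalent to Adam's 1/(1 - beta**step) bias correction, folded into
    # the EMA coefficient itself: at step 1 it returns 0, so the average
    # starts exactly at the first observation.
    return 1 - (1 - beta) / (1 - beta ** step)
```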
@@ -655,21 +658,17 @@ def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2
     return denom
 
 
-def exp_avg_(exp_avg: List[
-
-
-
-
-    beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
-    if isinstance(step, int):
-        step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+def exp_avg_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], grad_projected: List[Tensor],
+             beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(
+        grad), list_guard(grad_projected)
+    beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
     denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
     return denom
 
 
-
-
-def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
     """Taken as-is from https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905"""
     # create a random 16 bit integer
     result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
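`_compilable_copy_stochastic_` (now compiled as well) implements stochastic rounding for float32 → bfloat16 writes: add random noise below the bfloat16 mantissa cutoff, truncate, reinterpret. A self-contained re-derivation of the linked snippet, not the package's exact code:

```python
import torch

def copy_stochastic(target: torch.Tensor, source: torch.Tensor):
    assert source.dtype == torch.float32 and target.dtype == torch.bfloat16
    # Random 16-bit integer: noise strictly below the bits bfloat16 keeps.
    noise = torch.randint_like(source, dtype=torch.int32, low=0, high=1 << 16)
    bits = source.view(dtype=torch.int32)
    # Adding noise then zeroing the low 16 bits rounds up with probability
    # equal to the discarded fraction -- unbiased on average.
    rounded = (bits + noise) & -65536
    target.copy_(rounded.view(dtype=torch.float32))
```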
@@ -684,40 +683,40 @@ def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
     target.copy_(result.view(dtype=torch.float32))
 
 
-def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+def copy_stochastic_(target: Tensor, source: Tensor):
     if not is_compiling() and target.data_ptr() == source.data_ptr():
         return
-    if target.dtype
-
-
+    if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
+        _compilable_copy_stochastic_(target, source.float())
+    set_(target, source)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_update_(p, u, decay, add_fn, lr, caution,
+def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+                        g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
-    p32, u32
+    p32, u32 = [list(map(promote, x)) for x in [p, u]]
 
     if decay > 0:
         torch._foreach_mul_(p32, 1 - decay * lr)
 
-    for p32_, u32_,
+    for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
         if caution:
-            _compilable_cautioning_(
-
-            p32_.add_(u32_, alpha=lr)
-        else:
-            add_fn(p32_, u32_, lr)
+            _compilable_cautioning_(promote(g_), u32_)
+        add_fn(p32_, u32_, lr)
 
     copy_stochastic_list_(p, p32)
 
 
-def update_param_(param: List[
-
-    lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
+def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
+                  caution: bool = False, grad: List[Tensor] = None):
     param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
-
+    if add_fn is None:
+        add_fn = stochastic_add_
+    _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
 
 
 def precond_schedule(step, precond_scheduler, rng):
@@ -887,14 +886,14 @@ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def psgd_precond_grad(
+def psgd_precond_grad(inplace: bool, exprs: str, grad: Tensor, *preconds: Tensor):
     """Precondition gradient G with preconditioner Q."""
-    md = min_dtype(
-    out = torch.einsum(exprs
+    md = min_dtype(preconds)
+    out = torch.einsum(exprs, *[q.conj().to(md) for q in preconds], *[q.to(md) for q in preconds], grad.to(md))
     if inplace:
-        set_(
-        return
-    return out.to(
+        set_(grad, out)
+        return grad
+    return out.to(grad.dtype)
 
 
 def norm_clip_(x, scale=None):
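The new `psgd_precond_grad` signature makes the in-place path explicit and takes the Kronecker factors as varargs. For a 2D gradient, the einsum amounts to applying each factor's Gram matrix along its own dimension; a hedged sketch of that special case (transpose/conjugation conventions follow the stored expression string in the real package):

```python
import torch

Ql, Qr = torch.randn(8, 8), torch.randn(4, 4)  # left/right Kronecker factors
G = torch.randn(8, 4)                          # gradient

# Each factor enters the einsum once conjugated and once plain,
# i.e. P_i = Q_i^H Q_i is applied along dimension i of the gradient.
out = (Ql.T @ Ql) @ G @ (Qr.T @ Qr)
```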
@@ -957,7 +956,7 @@ def trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 
 @decorator
-def triu_to_line(Q_list: List[torch.Tensor]):
+def triu_to_line(Q_list: List[Tensor]):
     out = []
     for q in Q_list:
         if q.dim() < 2:
@@ -974,7 +973,7 @@ def _triu_shape(numel):
 
 
 @decorator
-def line_to_triu(Q_list: List[Tuple[Optional[List[int]], torch.Tensor]]):
+def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
     new = []
     for shape, q in Q_list:
         if shape is not None:
@@ -1031,22 +1030,22 @@ class PSGDBase(StatefulOptimizer):
 
 
 # TODO: Figure out why this sometimes crashes
-
-def _compilable_precond_grad_cached_(
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+def _compilable_precond_grad_cached_(ea: Tensor, expr: str, param: Tensor, lr: Tensor, weight_decay: Tensor,
+                                     clip_fn: callable, caution: bool, grad: Optional[Tensor], *cached_q: Tensor):
     md = min_dtype(cached_q + [ea])
     new = torch.einsum(expr, *[c_.to(md) for c_ in cached_q], ea.to(md)).to(torch.float32)
     update_param_([param], clip_fn([new]), lr, weight_decay, caution=caution, grad=grad)
 
 
-def precond_grad_cached_(cached_q: List[
-
-
-
-    _compilable_precond_grad_cached_(cached_q, ea, expr, param, lr, weight_decay, clip_fn, caution, grad)
+def precond_grad_cached_(cached_q: List[Tensor], ea: Tensor, expr: str, param: Tensor, lr: float, weight_decay: float,
+                         clip_fn, caution, grad):
+    lr = scalar_guard(lr, param)
+    _compilable_precond_grad_cached_(ea, expr, param, lr, weight_decay, clip_fn, caution, grad, *cached_q)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_mars_correction_(g, old_g, a):
+def _compilable_mars_correction_(g: Tensor, old_g: Tensor, a: Tensor):
     g_copy = [g_.clone() for g_ in g]
     _compilable_stochastic_lerp_(g, old_g, a)
     copy_stochastic_list_(old_g, g_copy)
@@ -1055,12 +1054,12 @@ def _compilable_mars_correction_(g, old_g, a):
 def mars_correction(g, old_g, beta1, gamma):
     a = -gamma * beta1 / (1 - beta1)
     g, old_g = list_guard(g), list_guard(old_g)
-    a = torch.empty((), dtype=torch.float32, device=g[0].device).fill_(a)
+    a = scalar_guard(a, g[0])
     _compilable_mars_correction_(g, old_g, a)
 
 
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
-def _compilable_cautioning_(g, update):
+def _compilable_cautioning_(g: Tensor, update: Tensor):
     mask = (g * update) > 0
     update.masked_fill_(~mask, 0)
     scale = mask.numel() / mask.sum().clamp(min=1)
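`_compilable_cautioning_` implements the cautious-optimizer mask: update components whose sign disagrees with the gradient are zeroed, and the survivors are rescaled so the mean magnitude is preserved. Worked example:

```python
import torch

g = torch.tensor([0.5, -1.0, 2.0, -0.1])        # gradient
update = torch.tensor([0.2, 0.3, 1.0, -0.4])    # proposed update

mask = (g * update) > 0                          # signs agree at indices 0, 2, 3
update = update.masked_fill(~mask, 0)            # -> [0.2, 0.0, 1.0, -0.4]
scale = mask.numel() / mask.sum().clamp(min=1)   # 4 / 3
update = update * scale                          # compensate for the zeroed entry
```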
--- a/heavyball.egg-info/PKG-INFO
+++ b/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.22.0
+Version: 0.23.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.
 
-Currently (2024-11-26, 0.22.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+Currently (2024-11-26, 0.22.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
 ## Features
--- a/test/test_bf16_params.py
+++ b/test/test_bf16_params.py
@@ -10,14 +10,6 @@ import torch._inductor.config as ind_cfg
 
 config.cache_size_limit = 128
 
-
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
 @pytest.mark.parametrize("opt", ['CachedDelayedPSGDKron'])
 @pytest.mark.parametrize("size,depth", [(256, 1)])
 def test_foreach(opt, size, depth: int, iterations: int = 16, outer_iterations: int = 3):
--- a/test/test_bf16_q.py
+++ b/test/test_bf16_q.py
@@ -11,14 +11,6 @@ from heavyball.utils import clean, set_torch, PSGDBase
 config.cache_size_limit = 128
 
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(256, 2)])
 def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
--- a/test/test_bf16_storage.py
+++ b/test/test_bf16_storage.py
@@ -10,12 +10,6 @@ from heavyball.utils import clean, set_torch, PSGDBase
 
 config.cache_size_limit = 128
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
 
 
 @pytest.mark.parametrize("opt", heavyball.__all__)
--- a/test/test_caution.py
+++ b/test/test_caution.py
@@ -9,15 +9,6 @@ from torch._dynamo import config
 
 config.cache_size_limit = 128
 
-
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 2)])
 def test_caution(opt, size, depth: int, iterations: int = 65536, outer_iterations: int = 2):
--- a/test/test_mars.py
+++ b/test/test_mars.py
@@ -10,17 +10,9 @@ from torch._dynamo import config
 config.cache_size_limit = 128
 
 
-def get_memory():
-    clean()
-    torch.cuda.synchronize()
-    clean()
-    torch.cuda.synchronize()
-    return torch.cuda.memory_allocated()
-
-
 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(128, 2)])
-def test_mars(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 2):
+def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: int = 2):
     set_torch()
     opt = getattr(heavyball, opt)
     if ScheduleFree in opt.__mro__:
@@ -35,11 +27,11 @@ def test_mars(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 2):
         losses.append([])
 
     for i in range(outer_iterations):
-        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+        model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().double()
         o = get_optim(opt, model.parameters(), lr=1e-5, mars=mars)
 
         for _ in range(iterations):
-            loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+            loss = model(torch.randn((1024, size), device='cuda', dtype=torch.double)).square().mean()
             loss.backward()
             o.step()
             o.zero_grad()
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|