heavyball 0.25.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
- heavyball/__init__.py +192 -29
- heavyball/chainable.py +475 -0
- heavyball/utils.py +334 -180
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/METADATA +4 -3
- heavyball-1.0.0.dist-info/RECORD +8 -0
- heavyball/cached_delayed_psgd_kron.py +0 -135
- heavyball/cached_psgd_kron.py +0 -136
- heavyball/delayed_psgd.py +0 -122
- heavyball/foreach_adamw.py +0 -63
- heavyball/foreach_adopt.py +0 -83
- heavyball/foreach_laprop.py +0 -67
- heavyball/foreach_sfadamw.py +0 -69
- heavyball/foreach_soap.py +0 -91
- heavyball/p_adam.py +0 -121
- heavyball/palm_foreach_sfadamw.py +0 -77
- heavyball/palm_foreach_soap.py +0 -100
- heavyball/precond_schedule_foreach_soap.py +0 -95
- heavyball/precond_schedule_palm_foreach_soap.py +0 -105
- heavyball/precond_schedule_sfpsoap.py +0 -141
- heavyball/psgd_kron.py +0 -120
- heavyball/pure_psgd.py +0 -105
- heavyball/schedule_free_palm_foreach_soap.py +0 -136
- heavyball-0.25.0.dist-info/RECORD +0 -24
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/LICENSE +0 -0
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/WHEEL +0 -0
- {heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/top_level.txt +0 -0
{heavyball-0.25.0.dist-info → heavyball-1.0.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.25.0
+Version: 1.0.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,11 +32,12 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-
+Currently (2024-12-07, 1.0.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features

+* **Optax-like API**: `C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)`
 * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
 * **Inplace EMA**: Same math, but less memory, less compute and higher stability
 * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
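The new "Optax-like API" bullet compresses a full workflow into one line. A hedged sketch of what that composition could look like in practice, assuming `ChainOpt` accepts `(params, lr, *transforms)` as in the README one-liner and otherwise behaves like a regular `torch.optim` optimizer (the model and training step below are illustrative, not from the package):

```python
import torch
import heavyball.chainable as C

model = torch.nn.Linear(16, 16)

# Compose gradient transforms Optax-style: an EMA of the gradient ("grokfast"-like)
# feeds into Adam scaling. Names follow the README snippet; everything else is assumed.
opt = C.ChainOpt(model.parameters(), 1e-3, C.exp_avg, C.scale_by_adam)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```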
@@ -130,6 +131,6 @@ To access `heavyball.utils`, you need to explicitly `import heavyball.utils`.\
 It has several handy functions:

 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
-* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls
+* `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
 * `zeroth_power_mode`, a string determining whether to use QR, newtonschulz{iterations}, or svd or eigh to approximate
   the eigenvectors. Eigh has the highest precision and cost
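A hedged usage sketch of the settings described in this hunk; the attribute names come from the README excerpt, while the concrete values (e.g. the lowercase `'qr'` spelling) and module-level assignment style are assumptions:

```python
import heavyball.utils

heavyball.utils.set_torch()               # TF32, opt_einsum, cudnn benchmark, ...
heavyball.utils.compile_mode = None       # per the new README line: None disables torch.compile
heavyball.utils.zeroth_power_mode = 'qr'  # assumed spelling; cheaper than 'eigh', lower precision
```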
heavyball-1.0.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
+heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
+heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
+heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
+heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.0.0.dist-info/RECORD,,
heavyball/cached_delayed_psgd_kron.py
DELETED
@@ -1,135 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-from typing import Optional
-
-import torch
-from heavyball.utils import min_dtype, precond_grad_cached_
-
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
-
-
-class ForeachCachedDelayedPSGDKron(PSGDBase):
-    """
-    Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP) with cached preconditioners.
-
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        beta (float): Momentum parameter.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                 #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= beta < 1.0:
-            raise ValueError(f"Invalid beta parameter: {beta}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        if clip_fn is None:
-            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        should_update = self.should_update(group)
-        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        beta = group['beta']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-                state['Q_cache'] = [torch.empty_like(q) for q in Q]
-
-                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
-                expr = ','.join(expr)
-                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
-                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
-                expr = f'{expr},{grad_expr}->{out_expr}'
-
-                state['cache_expr'] = expr
-
-            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
-
-        if not vals:
-            return
-
-        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
-
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
-            exp_avg_list)
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            cached_q = Q_cache_list.pop(0)
-            q_orig = Q_list.pop(0)
-            ea = exp_avg_list.pop(0)
-
-            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn, group['caution'], g)
-
-            if should_update:
-                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                               store_triu_as_line)
-                for c_, q_ in zip(cached_q, q):
-                    if q_.ndim == 2:
-                        torch.matmul(q_.T.conj(), q_, out=c_)
-                    else:
-                        torch.mul(q_.conj(), q_, out=c_)
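For migration reference, a hedged sketch of how this removed class could be instantiated under heavyball 0.25.0, based on the `__init__` signature shown above (the model is a stand-in; all other arguments keep the deleted defaults):

```python
import torch
from heavyball.cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron  # 0.25.0 module path

model = torch.nn.Linear(32, 32)
opt = ForeachCachedDelayedPSGDKron(
    model.parameters(), lr=1e-3, beta=0.9, weight_decay=0.0,
    max_size_triangular=2048,   # dims above this size get diagonal instead of triangular preconditioners
    store_triu_as_line=True,    # store triangular Q factors as flat lines to save memory
)
```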
heavyball/cached_psgd_kron.py
DELETED
@@ -1,136 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-from typing import Optional
-
-import torch
-
-from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias, precond_grad_cached_
-
-
-class ForeachCachedPSGDKron(PSGDBase):
-    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        beta (float): Momentum parameter.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-                 foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
-                 storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
-                 orthogonalize_output: bool = False,
-                 #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= beta < 1.0:
-            raise ValueError(f"Invalid beta parameter: {beta}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        if clip_fn is None:
-            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
-                        orthogonalize_output=orthogonalize_output)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        beta = group['beta']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-        storage_dtype = getattr(torch, group['storage_dtype'])
-        orthogonalize_output = group['orthogonalize_output']
-        should_update = self.should_update(group)
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
-                state['Q_cache'] = [torch.empty_like(q) for q in Q]
-
-                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
-                expr = ','.join(expr)
-                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
-                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
-                expr = f'{expr},{grad_expr}->{out_expr}'
-
-                state['cache_expr'] = expr
-
-            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
-
-        if not vals:
-            return
-
-        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
-
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-
-        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
-            exp_avg_list)
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            cached_q = Q_cache_list.pop(0)
-            q_orig = Q_list.pop(0)
-            ea = exp_avg_list.pop(0)
-
-            if should_update:
-                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-                q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                               store_triu_as_line)
-                for c_, q_ in zip(cached_q, q):
-                    if q_.ndim == 2:
-                        torch.matmul(q_.T.conj(), q_, out=c_)
-                    else:
-                        torch.mul(q_.conj(), q_, out=c_)
-
-            precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay, self.clip_fn,
-                                 group['caution'], g)
heavyball/delayed_psgd.py
DELETED
@@ -1,122 +0,0 @@
-"""
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
-import torch
-from heavyball.utils import stochastic_lerp_, beta_debias, stochastic_add_
-
-from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-    triu_to_line, line_to_triu, promote,_compilable_update_, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay, clip_fn, caution, grad):
-    new = psgd_precond_grad(False, exprs, ea, *q)
-    _compilable_update_([p], clip_fn([new]), weight_decay, stochastic_add_, lr, caution, [grad])
-
-
-class ForeachDelayedPSGD(PSGDBase):
-    """
-    Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
-
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Learning rate.
-        beta (float): Momentum parameter.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs
-            to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
-            to set all preconditioners to be triangular, 'one_diag' sets the largest
-            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
-            to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
-            update instead of raw gradients.
-    """
-
-    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
-                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
-                 q_dtype='float32', stochastic_schedule: bool = True, storage_dtype: str = 'float32',
-                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, #
-                 # expert parameters
-                 precond_init_scale=1.0, precond_lr=0.1):
-        if not 0.0 <= lr:
-            raise ValueError(f"Invalid learning rate: {lr}")
-        if not 0.0 <= beta < 1.0:
-            raise ValueError(f"Invalid beta parameter: {beta}")
-        if not 0.0 <= weight_decay:
-            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-
-        if clip_fn is None:
-            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
-
-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype,
-                        caution=caution, mars_gamma=mars_gamma, mars=mars)
-        super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
-
-    def _step(self, group):
-        should_update = self.should_update(group)
-        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
-        precond_init_scale = group['precond_init_scale']
-        max_size_triangular = group['max_size_triangular']
-        min_ndim_triangular = group['min_ndim_triangular']
-        memory_save_mode = group['memory_save_mode']
-        precond_lr = group['precond_lr']
-        weight_decay = group['weight_decay']
-        lr = group['lr']
-        beta = group['beta']
-        store_triu_as_line = group['store_triu_as_line']
-        q_dtype = getattr(torch, group['q_dtype'])
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        vals = []
-
-        for p, g in self.split_p_and_g_in_group(group, should_promote=False, beta1=beta):
-            state = self.state_(p)
-
-            if 'Q' not in state:
-                state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype, memory_format=torch.preserve_format)
-                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
-                                                 memory_save_mode, dtype=q_dtype)
-                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
-
-            vals.append((p, g, state["exp_avg"], state["Q"]))
-
-        if not vals:
-            return
-
-        p_list, grad_list, exp_avg_list, Q_list = zip(*vals)
-        del vals
-
-        group["step"] += 1
-
-        stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
-
-        lr = -warmup(lr, group['step'], group['warmup_steps'])
-        lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)
-
-        Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
-        for i, (p, g) in enumerate(zip(p_list, grad_list)):
-            q_orig = Q_list.pop(0)
-            ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
-            _compilable_psgd_precond_grad_(q, self.state_(p)["exprs"][-1], ea, p, lr, weight_decay, self.clip_fn, group['caution'],
-                                           g)
-            if should_update:
-                q32 = [promote(q_) for q_ in q]
-                self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-                               store_triu_as_line)
heavyball/foreach_adamw.py
DELETED
@@ -1,63 +0,0 @@
-import torch
-import torch.optim
-from heavyball.utils import copy_stochastic_list_
-
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-
-    torch._foreach_lerp_(exp_avg32, g32, 1 - beta_debias(beta1, step + 1))
-    denom = list(exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step + 1), eps))
-
-    update_param_(y, exp_avg32, lr, decay, lambda p, e, l: p.addcdiv_(e, denom.pop(0), value=l), caution=caution,
-                  grad=g32)
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
-class ForeachAdamW(StatefulOptimizer):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                        mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        for p in active_p:
-            if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-        y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-        if group['mars']:
-            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
-                          group['caution'])
-
-        group['k'] = k + 1
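The deleted `_compilable_step_` above leans on the README's "Inplace EMA" idea: `beta_debias` rescales the momentum coefficient so that a plain lerp keeps the buffer already bias-corrected, with no separate correction pass. A minimal single-tensor sketch, assuming the standard debiasing formula (the helper below is a reimplementation for illustration, not imported from heavyball):

```python
import torch

def beta_debias(beta: float, step: int) -> float:
    # Effective beta such that lerping with weight (1 - beta_hat) keeps the buffer
    # equal to the bias-corrected average m_t / (1 - beta**t).
    return 1 - (1 - beta) / (1 - beta ** step)

exp_avg = torch.zeros(4)
for t, g in enumerate([torch.randn(4) for _ in range(10)], start=1):
    exp_avg.lerp_(g, 1 - beta_debias(0.9, t))   # mirrors torch._foreach_lerp_ above
# exp_avg now holds the debiased first moment directly
```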
heavyball/foreach_adopt.py
DELETED
@@ -1,83 +0,0 @@
-import torch
-import torch.optim
-from heavyball.utils import copy_stochastic_list_
-
-from .utils import warmup, beta_debias, update_param_, StatefulOptimizer, promote, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-    update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
-
-    beta1 = beta_debias(beta1, step)
-    denom = torch._foreach_sqrt(exp_avg_sq32)
-    torch._foreach_maximum_(denom, eps)
-    torch._foreach_mul_(exp_avg32, beta1)
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-
-    beta2 = beta_debias(beta2, step + 1)
-    torch._foreach_mul_(exp_avg_sq32, beta2)
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
-class ForeachADOPT(StatefulOptimizer):
-
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                        mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        for p in active_p:
-            if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-        y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) for p in active_p])
-
-        group['k'] = k + 1
-
-        if group['mars']:
-            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-        if k > 1:
-            lr = -warmup(group['lr'], k - 1, group['warmup_steps'])
-            lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-            k = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k)
-            _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], k, lr, eps, decay, group['caution'])
-            return
-
-        grad = [promote(g) for g in grad]
-        if k > 0:
-            beta1 = beta_debias(group['betas'][0], k)
-            denom = torch._foreach_sqrt(exp_avg_sq)
-            torch._foreach_maximum_(denom, eps)
-            torch._foreach_mul_(exp_avg, beta1)
-            torch._foreach_addcdiv_(exp_avg, grad, denom, 1 - beta1)
-
-        beta2 = beta_debias(group['betas'][1], k + 1)
-        torch._foreach_mul_(exp_avg_sq, beta2)
-        torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2)
-        del grad
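The deleted `_compilable_step_` above encodes ADOPT's off-by-one ordering: the parameter steps along the existing momentum, the momentum is then refreshed with the gradient normalized by the previous second moment, and only afterwards is the second moment updated with the current gradient. A hypothetical single-tensor sketch of that ordering (debiasing, warmup, weight decay and stochastic rounding omitted; `adopt_like_step` is an illustrative name, not heavyball API):

```python
import torch

def adopt_like_step(p, g, exp_avg, exp_avg_sq, lr=2.5e-3, beta1=0.9, beta2=0.99, eps=1e-8):
    p.add_(exp_avg, alpha=-lr)                      # 1) step with the existing momentum buffer
    denom = exp_avg_sq.sqrt().clamp_(min=eps)       # 2) normalize by the *previous* second moment
    exp_avg.mul_(beta1).addcdiv_(g, denom, value=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # 3) only now fold in the current gradient
```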
heavyball/foreach_laprop.py
DELETED
@@ -1,67 +0,0 @@
-import torch
-import torch.optim
-
-from .utils import warmup, exp_avg_sq_, beta_debias, update_param_, StatefulOptimizer, promote, copy_stochastic_list_, decorator_knowngood
-
-
-@decorator_knowngood
-def _compilable_step_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
-
-    denom = exp_avg_sq_(exp_avg_sq32, g32, beta_debias(beta2, step), eps)
-
-    beta1 = beta_debias(beta1, step)
-    torch._foreach_mul_(exp_avg32, beta1)
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-
-    update_param_(y, exp_avg32, lr, decay, caution=caution, grad=g32)
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-
-class ForeachLaProp(StatefulOptimizer):
-
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=1,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025):
-        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
-                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
-                        mars_gamma=mars_gamma)
-        super().__init__(params, defaults, foreach)
-
-    def _step(self, group):
-        eps = group['eps']
-        decay = group['weight_decay']
-        k = group['k']
-
-        if not group['train_mode']:
-            raise Exception("Not in train mode!")
-
-        active_p = [p for p in group['params'] if p.grad is not None]
-
-        if not active_p:
-            return
-
-        storage_dtype = getattr(torch, group['storage_dtype'])
-
-        for p in active_p:
-            if 'exp_avg' not in self.state_(p):
-                self.state_(p)['exp_avg'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-                self.state_(p)['exp_avg_sq'] = torch.zeros_like(p.data, dtype=storage_dtype, memory_format=torch.preserve_format)
-
-        y, grad, exp_avg_sq, exp_avg = zip(
-            *[(p.data, p.grad, self.state_(p)['exp_avg_sq'], self.state_(p)['exp_avg']) #
-              for p in active_p])
-
-        if group['mars']:
-            self.mars_correct_list(group, y, grad, group['mars_gamma'], group['betas'][0])
-
-        lr = -warmup(group['lr'], k + 1, group['warmup_steps'])
-        lr = torch.empty((), dtype=torch.float32, device=y[0].device).fill_(lr)
-        step = torch.empty((), dtype=torch.int32, device=y[0].device).fill_(k + 1)
-
-        _compilable_step_(y, grad, exp_avg_sq, exp_avg, group['betas'][0], group['betas'][1], step, lr, eps, decay,
-                          group['caution'])
-
-        group['k'] = k + 1