heavyball 1.5.3-py3-none-any.whl → 1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +63 -20
- heavyball/chainable.py +15 -12
- heavyball/utils.py +264 -197
- {heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/METADATA +4 -6
- heavyball-1.6.1.dist-info/RECORD +8 -0
- heavyball-1.5.3.dist-info/RECORD +0 -8
- {heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/LICENSE +0 -0
- {heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/WHEEL +0 -0
- {heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
  Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
  https://arxiv.org/abs/2409.11321
  https://github.com/nikhilvyas/SOAP
-
- ScheduleFree:
- The Road Less Scheduled
- Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
- https://arxiv.org/abs/2405.15682
- https://github.com/facebookresearch/schedule_free
  """
  use_precond_schedule: bool = False

  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
-
-
-
-
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

  defaults = locals()
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
  C.scale_by_soap)


+ class ForeachSignLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+ class ForeachSOLP(C.BaseOpt):
+ """
+ ForeachSOLP
+
+ Sources:
+ Baseline SOAP:
+ SOAP: Improving and Stabilizing Shampoo using Adam
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+ https://arxiv.org/abs/2409.11321
+ https://github.com/nikhilvyas/SOAP
+ """
+ use_precond_schedule: bool = False
+
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
+ if use_precond_schedule:
+ del defaults['precondition_frequency']
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+ else:
+ del defaults['precond_scheduler']
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
+ functools.partial(C.scale_by_soap, inner='laprop'))
+
+
  class PaLMForeachSOAP(ForeachSOAP):
  use_precond_schedule: bool = False
  palm: bool = True
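The new ForeachSOLP wires SOAP's preconditioner around a LaProp inner step via functools.partial(C.scale_by_soap, inner='laprop'), and ForeachSignLaProp chains C.scale_by_laprop with C.sign. A minimal usage sketch, assuming a `model` and `batch` that are not part of this diff (constructor values are the defaults shown above):

import heavyball

opt = heavyball.ForeachSOLP(model.parameters(), lr=3e-3, weight_decay=0.01)
# opt = heavyball.SignLaProp(model.parameters(), lr=0.0025)  # alias registered further down in this file
loss = model(batch).mean()
loss.backward()
opt.step()
opt.zero_grad()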
@@ -163,7 +205,6 @@ class OrthoLaProp(C.BaseOpt):
  C.orthogonalize_grad_to_param, C.scale_by_laprop)


-
  class LaPropOrtho(C.BaseOpt):
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -172,8 +213,8 @@ class LaPropOrtho(C.BaseOpt):
  defaults = locals()
  defaults.pop("self")
  params = defaults.pop("params")
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
- C.
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+ C.orthogonalize_grad_to_param)


  class ForeachPSGDKron(C.BaseOpt):
@@ -191,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool =
+ stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  defaults = locals()
@@ -251,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  Muon = ForeachMuon
+ SignLaProp = ForeachSignLaProp

  __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
-
- "
- "
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
+ "ForeachAdamW", "ForeachSFAdamW",
+ "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+ "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+ 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
heavyball/chainable.py
CHANGED
@@ -127,7 +127,7 @@ def zero_guard(*names):


  def copy_guard(index, *names):
- return functools.partial(CopyGuard, index=index, names=names
+ return functools.partial(CopyGuard, index=index, names=names)


  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -222,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
  @no_state
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -241,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  @no_state
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)


- def _init_soap(state, group, update, grad, param):
- utils.init_preconditioner(grad, state,
+ def _init_soap(state, group, update, grad, param, inner: str = ''):
+ utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])


  def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -343,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
- update = utils.promote(update)
+ update = utils.promote(update)  # Promote to highest precision if needed

  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
  fn = _optim_fns[inner]
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+ group['eps'])
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

- for u, q, gg,
- utils.update_preconditioner(u, q, gg,
+ for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+ utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
  utils.beta_debias(group['shampoo_beta'], group['step']),
  group['is_preconditioning'])
  return precond
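The `inner` argument selects which optimizer runs inside SOAP's eigenbasis through the module-level `_optim_fns` table named in the hunk header; ForeachSOLP builds this transform with functools.partial(C.scale_by_soap, inner='laprop'). A small sketch of just that dispatch, assuming the module layout shown in this diff:

from heavyball import chainable, utils

assert chainable._optim_fns['adam'] is utils.adam_      # default inner optimizer
assert chainable._optim_fns['laprop'] is utils.laprop_  # what ForeachSOLP selects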
@@ -503,7 +506,7 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
  def _branch(state, group, update, grad, param):
  outputs = []
  for branch in branches:
- branch_update = [torch.clone(
+ branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
  branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
  if skip_update:
  raise ValueError("Branches should not skip updates")
heavyball/utils.py
CHANGED
@@ -5,20 +5,30 @@ import random
  import string
  import warnings
  from typing import List, Optional, Tuple, Callable, Union
+ from unittest.mock import patch

  import numpy as np
  import torch
  from torch import Tensor
+ from torch._dynamo import config
  from torch._dynamo.exc import TorchDynamoException
  from torch.backends import cudnn, opt_einsum
  from torch.utils._pytree import tree_map

+ config.cache_size_limit = 2 ** 16
+
+ np.warnings = warnings
+
  compile_mode = "max-autotune-no-cudagraphs"
  dynamic = False
  compile_mode_recommended_to_none = None
  zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny

+ base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
+ 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
+ 'weight_decay': 1e-4}
+

  def decorator(func):
  compiled = None
@@ -51,7 +61,7 @@ def decorator_knowngood(func: Callable):
  return _fn


- einsum_base = string.ascii_lowercase
+ einsum_base = string.ascii_lowercase


  @decorator_knowngood
@@ -103,29 +113,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
  but we want to merge conv kernels into fan-in or at least merge the kernel
  so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
  would've done
+
+ By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
  """
- shape = grad.shape
  new_shape = []
-
-
-
-
-
-
-
-
- curr_shape = sh
+ cum_size = 1
+
+ for s in grad.shape[1:][::-1]:
+ temp_size = cum_size * s
+ if temp_size > max_precond_dim:
+ if cum_size > 1:
+ new_shape.append(cum_size)
+ cum_size = s
  else:
- new_shape.append(
-
+ new_shape.append(s)
+ cum_size = 1
  else:
-
- new_shape = [*shape[:1], *new_shape[::-1]]
+ cum_size = temp_size

- if
- new_shape.append(
+ if cum_size > 1:
+ new_shape.append(cum_size)

-
+ new_shape = [grad.shape[0], *new_shape[::-1]]
+ new_grad = grad.reshape(new_shape)
  if not split:
  return new_grad

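A quick check of the merging logic above with the ForeachSOAP default max_precond_dim=2048 and the conv shape from the docstring: walking the trailing dims in reverse, 3*3=9 and 9*64=576 stay below 2048, so they collapse into a single axis and [128, 64, 3, 3] becomes [128, 576]. An illustrative driver (the shape is the docstring's example, not part of the diff):

import torch
from heavyball import utils

g = torch.zeros(128, 64, 3, 3)
print(utils.dim_merger(g, 2048).shape)   # torch.Size([128, 576])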
@@ -149,15 +159,17 @@ def beta_debias(beta, step):
  return 1 - (1 - beta) / (1 - beta ** step)


+ def eps_sqrt(item, eps):
+ return item.sqrt().clamp(min=eps)
+
+
  @decorator_knowngood
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
  out: List[Optional[Tensor]]):
-
- s32 = torch._foreach_mul(
-
- denom =
- denom = [d.clamp(min=eps) for d in denom]
- copy_stochastic_list_(state, s32)
+ g32 = promote(grad)
+ s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
+
+ denom = [eps_sqrt(d, eps) for d in s32]

  if out[0] is None:
  return denom
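eps_sqrt folds the sqrt-then-clamp denominator into one helper: for a second-moment entry of 1e-20 with eps=1e-8, the square root is 1e-10 and the clamp lifts it to 1e-8. A tiny illustrative check (the values are made up):

import torch
from heavyball.utils import eps_sqrt

print(eps_sqrt(torch.tensor(1e-20), 1e-8))   # tensor(1.0000e-08)
print(eps_sqrt(torch.tensor(4.0), 1e-8))     # tensor(2.)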
@@ -189,7 +201,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):

  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
- lerped =
+ lerped = _lerp(state, grad, beta)
  copy_stochastic_list_(grad, lerped)

@@ -240,29 +252,31 @@ def clean():
  gc.collect()


- def
+ def _ignore_warning(msg):
+ warnings.filterwarnings('ignore', f'.*{msg}.*')
+
+
+ def set_torch(benchmark_limit: int = 32):
  cudnn.benchmark = True
  cudnn.deterministic = False
+ cudnn.benchmark_limit = benchmark_limit
  torch.use_deterministic_algorithms(False)
  torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
- opt_einsum.enabled =
- opt_einsum.strategy = "auto
+ opt_einsum.enabled = False
+ opt_einsum.strategy = "auto"
+
+ # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
+ _ignore_warning(
+ 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
+ _ignore_warning(
+ 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')


  @decorator
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
- """
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
- performance at all relative to UV^T, where USV^T = G is the SVD.
- """
  assert len(G.shape) == 2
  a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.bfloat16
+ X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype)  # Preserve float64 if present
  X /= (X.norm() + eps)  # ensure top singular value <= 1
  if G.size(0) > G.size(1):
  X = X.T
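The rewritten cast keeps float64 inputs in float64 instead of always dropping to bfloat16. A standalone check of just that expression, outside the Newton-Schulz loop (plain PyTorch, nothing heavyball-specific assumed):

import torch

G64 = torch.randn(4, 4, dtype=torch.float64)
G32 = torch.randn(4, 4)
assert G64.to(torch.bfloat16 if G64.dtype != torch.float64 else G64.dtype).dtype == torch.float64
assert G32.to(torch.bfloat16 if G32.dtype != torch.float64 else G32.dtype).dtype == torch.bfloat16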
@@ -319,7 +333,7 @@ def nesterov_momentum(state, grad, beta):

  @decorator_knowngood
  def _compilable_nesterov_ema_(state, grad, beta):
- ema32 =
+ ema32 = _lerp(state, grad, beta)
  stochastic_add_(grad, ema32, 1)

@@ -330,12 +344,11 @@ def nesterov_ema(state, grad, beta):
  return grad


+ @decorator_knowngood
  def _compilable_grafting(magnitude, direction):
  return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))


- # mode in ("newtonschulz", "qr", "svd")
- # scale_mode in ("none", "scale", "graft")
  @decorator_knowngood
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
@@ -363,74 +376,84 @@ def _compilable_scatter_set(target, source, index):
  target[:] = source.contiguous()[index].reshape_as(target)


-
+ #@decorator_knowngood
+ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
  """
  Computes the eigenbases of the preconditioner using one round of power iteration
- followed by torch.linalg.qr decomposition.
+ followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
+
+ :param GG: List of accumulated gradient outer products.
+ :param Q: List of current eigenbases (updated in-place to Q_new).
+ :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
  """
-
-
- for m, o in zip(GG, Q):
- if len(m) == 0:
- matrix.append([])
- orth_matrix.append([])
- continue
- if m.data.dtype != torch.float:
- matrix.append(promote(m.data))
- orth_matrix.append(promote(o.data))
- else:
- matrix.append(promote(m.data))
- orth_matrix.append(promote(o.data))
+ if isinstance(Q, list) and not Q:
+ return

-
+ if exp_avg is not None and exp_avg.dim() != len(Q):
+ raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")

-
-
-
+ new_qs = []
+
+ for m, q in zip(GG, Q):
+ if m is None:
+ new_qs.append(None)
  continue

-
-
+ m = promote(m.data)
+ q_old = promote(q.data)
+
+ tmp = m @ q_old
+ est_eig = torch.einsum('ij,ij->j', q_old, tmp)
  sort_idx = torch.argsort(est_eig, descending=True)
- indices.append(sort_idx)
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

-
-
-
+ tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
+ new_qs.append(tmp)
+
+ if exp_avg is None:
+ for q, q_new in zip(Q, new_qs):
+ copy_stochastic_(q, q_new)
+ return
+
+ assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
+ in_str = einsum_base[:exp_avg.dim()]
+ out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
+
+ from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
+ if not from_shampoo:
+ return
+
+ to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
+ out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
+
+ subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None])
+ copy_stochastic_(exp_avg, exp_avg_new)
+
+ for q, q_new in zip(Q, new_qs):
+ if q is not None:
+ copy_stochastic_(q, q_new)


  def get_orthogonal_matrix(mat):
  """
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
  """
- matrix = []
- for m in mat:
- if len(m) == 0:
- matrix.append([])
- continue
- if m.data.dtype != torch.float:
- float_data = False
- original_type = m.data.dtype
- original_device = m.data.device
- matrix.append(promote(m.data))
- else:
- float_data = True
- matrix.append(m.data)

  final = []
- for m in
- if
- final.append(
+ for m in mat:
+ if m is None:
+ final.append(None)
  continue

+ m = promote(m.data)
+
  device, dtype = m.device, m.dtype
  for modifier in (None, torch.double, 'cpu'):
  if modifier is not None:
  m = m.to(modifier)
  try:
-
-
+ eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
+ eigvec = eigvec.to(device=device, dtype=dtype)
  break
  except torch.OutOfMemoryError:
  pass
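For a 2-D exp_avg the subscript construction above resolves to in_str='ab', out_str='cd', from_shampoo='Aa,Bb' and to_shampoo='Ac,Bd', so the contraction is 'ab,Aa,Bb,Ac,Bd->cd': the moment is projected out of the old eigenbases and into the new ones in a single einsum. A standalone check of that contraction, with random orthogonal matrices standing in for the real eigenbases:

import torch

exp_avg = torch.randn(5, 7)
q_old = [torch.linalg.qr(torch.randn(5, 5))[0], torch.linalg.qr(torch.randn(7, 7))[0]]
q_new = [torch.linalg.qr(torch.randn(5, 5))[0], torch.linalg.qr(torch.randn(7, 7))[0]]
rotated = torch.einsum('ab,Aa,Bb,Ac,Bd->cd', exp_avg, *q_old, *q_new)
print(rotated.shape)   # torch.Size([5, 7])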
@@ -440,9 +463,9 @@ def get_orthogonal_matrix(mat):
  else:
  raise RuntimeError("Failed to compute eigenvalues.")

-
+ eigvec = torch.flip(eigvec, [1])

- final.append(
+ final.append(eigvec)

  return final

@@ -452,7 +475,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
-
+ if x32.dtype != y32.dtype:
+ y32 = y32.to(x32.dtype)
+ copy_stochastic_(x_, x32 * (1 - a) + y32 * a)


  def get_beta1(group):
@@ -467,7 +492,7 @@ def get_beta1(group):


  def get_beta2(group):
- if 'beta2_scale' in group:
+ if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
  step = max(group.get("step", 1), 1)
  return 1 - step ** -group['beta2_scale']
  if 'betas' in group:
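With this change the PaLM-style schedule beta2 = 1 - step**(-beta2_scale) is only used when the group was created with palm=True; otherwise get_beta2 falls through to the regular betas entry. For the beta2_scale=0.8 default used throughout this diff, the schedule gives roughly 0.84 at step 10 and 0.996 at step 1000, so the second-moment horizon lengthens as training progresses:

step, beta2_scale = 1000, 0.8
print(1 - step ** -beta2_scale)   # ~0.996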
@@ -536,20 +561,32 @@ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):


  @decorator
- def
+ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
+ """
+ Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
+ Re-commited due to faulty merge
+ """
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
  return

- for idx,
- if
+ for idx, m in enumerate(GG):
+ if not isinstance(m, Tensor):
  continue
  b = einsum_base[idx]
  g0 = einsum_base[:grad.dim()]
  g1 = g0.replace(b, b.upper())
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-
+ stochastic_lerp_(m, outer_product, 1 - beta)
+
+
+ def tree_apply(fn):
+ def _fn(*args):
+ return tree_map(fn, *args)
+
+ return _fn


+ @tree_apply
  def promote(x):
  if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
  return torch.float32
@@ -566,63 +603,85 @@ def min_dtype(xs: List[Tensor]):
  return torch.float32


- def update_preconditioner(grad, Q, GG,
+ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
  """
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
  """
-
+ update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
  if update_precond:
- get_orthogonal_matrix_QR(GG, Q,
+ get_orthogonal_matrix_QR(GG, Q, exp_avg)


- def init_preconditioner(grad, state,
+ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
  """
  Initializes the preconditioner matrices (L and R in the paper).
  """
  state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
- if grad.
- if precondition_1d or grad.shape[0] > max_precond_dim:
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
- else:
- state['GG'].append([])
-
- else:
+ if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
  for sh in grad.shape:
- if sh > max_precond_dim:
-
+ if sh > max_precond_dim or sh == 1:
+ # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
+ state['GG'].append(None)
  else:
  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+ else:
+ state['GG'].append(None)

-
+ update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
  state['Q'] = get_orthogonal_matrix(state['GG'])


  @decorator
  def project(grad, Q, back: bool):
  """
-
  :param grad:
  :param Q:
- :param merge_dims:
- :param max_precond_dim:
  :param back: whether to project to Shampoo eigenbases or back to original space
  :return:
  """
  param = einsum_base[:grad.dim()]
- preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if
+ preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
  if preconditioners:
  out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
- out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if
+ out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
  grad = out.to(grad.dtype)
  return grad


+ def modify_closure(closure):
+ """
+ Modifies the closure function to use create_graph=True in backward().
+
+ Args:
+ closure: The closure function passed to the optimizer.
+
+ Returns:
+ The return value of the modified closure.
+ """
+
+ def patched_backward(self, *args, **kwargs):
+ kwargs['create_graph'] = True
+ return original_backward(self, *args, **kwargs)
+
+ original_backward = torch.Tensor.backward
+
+ with patch.object(torch.Tensor, 'backward', patched_backward):
+ return closure()
+
+
  class StatefulOptimizer(torch.optim.Optimizer):
+ """
+ finite_differences saves memory, but needs more compute. (Alternative is true HVP)
+ Both `True` and `False` have some edge cases they don't support, so experiment with it.
+ The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
+ Further notice that both methods have different numerics outputs
+ """
  ema_decay: float = 0.001
  compile_step: bool = False
  hessian_approx: bool = False
  precond_schedule: Union[Callable, float, None] = None
  stochastic_schedule: bool = False
+ finite_differences: bool = False

  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
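modify_closure temporarily patches torch.Tensor.backward so that a user-supplied closure builds its backward graph with create_graph=True, which the Hessian-vector-product path added below relies on. A minimal sketch of a closure flowing through it; the parameter here is an assumption, not code from the package:

import torch
from heavyball.utils import modify_closure

p = torch.nn.Parameter(torch.randn(3))

def closure():
    loss = (p ** 2).sum()
    loss.backward()
    return loss

loss = modify_closure(closure)   # backward() ran with create_graph=True
print(p.grad, loss)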
@@ -735,38 +794,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
  set_(self.state_(p)['param_ema'], p.data)
  set_(p.data, ema_clone)

- def
- if self.precond_schedule is None:
- self._is_preconditioning = False
- else:
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+ def _handle_closure(self, closure):
  hessian_approx = self.hessian_approx and self._is_preconditioning
+
  if closure is None:
  if hessian_approx:
  raise ValueError("Hessian approximation requires a closure.")
-
-
+ return None
+
+ if not hessian_approx:
  with torch.enable_grad():
  loss = closure()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ return loss
+
+ if self.finite_differences:
+ with torch.enable_grad():
+ loss = closure()  # closure without retain_graph=True
+
+ grads = []
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ grads.append(g)
+ p.vector = torch.randn_like(p)
+ p.orig = p.data.clone()
+ stochastic_add_(p.data, p.vector, tiny_bf16)
+ else:
+ with torch.enable_grad():
+ loss = modify_closure(closure)
+
+ if self.finite_differences:
+ with torch.enable_grad():
+ closure()
+
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ p.grad = grads.pop(0)
+ stochastic_add_(g, p.grad, -1)
+ p.hessian_vector = g
+ p.data.copy_(p.orig)
+ del p.orig
+ else:
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ p.grad = g
+ params, grads = zip(*[x for group in self.param_groups for x in
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
+ vs = [torch.randn_like(p) for p in params]
+ with torch.enable_grad():
+ hvs = torch.autograd.grad(grads, params, vs)
+
+ for p, g, v, hv in zip(params, grads, vs, hvs):
+ p.hessian_vector = hv
+ p.grad = g
+ p.vector = v
+
+ return loss
+
+ def step(self, closure: Optional[Callable] = None):
+ if self.precond_schedule is None:
+ self._is_preconditioning = False
+ else:
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+ loss = self._handle_closure(closure)

  # we assume that parameters are constant and that there are no excessive recompiles
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
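The non-finite-differences branch above obtains an exact Hessian-vector product: gradients are created with a retained graph (via modify_closure) and then differentiated once more against a random probe vector. A standalone sketch of that pattern, detached from the optimizer plumbing (the cubic loss is only an illustration):

import torch

p = torch.randn(4, requires_grad=True)
loss = (p ** 3).sum()
g, = torch.autograd.grad(loss, p, create_graph=True)   # like modify_closure's create_graph=True backward
v = torch.randn_like(p)                                # the probe vector (p.vector in the diff)
hv, = torch.autograd.grad(g, p, v)                     # Hessian-vector product (p.hessian_vector)
print(torch.allclose(hv, 6 * p * v))                   # True for this loss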
@@ -774,7 +863,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
  group['is_preconditioning'] = self._is_preconditioning
  self._step(group)
  if self.use_ema:
- self.ema_update(
+ self.ema_update()

  return loss

@@ -784,12 +873,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
  copy_stochastic_(t, s)


-
+ @decorator_knowngood
+ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
  ea32 = list(map(promote, state))
  grad = list(map(promote, grad))
  beta = promote(beta)
-
- ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+ stochastic_lerp_(ea32, grad, 1 - beta)
  copy_stochastic_list_(state, ea32)
  return ea32

@@ -801,9 +890,8 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
  beta2 = beta_debias(beta2, step)

  g32 = list(map(promote, grad))
-
-
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
  u32 = torch._foreach_div(exp_avg32, denom)
  copy_stochastic_list_(grad, u32)

@@ -824,9 +912,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
  beta2 = beta_debias(beta2, step)

  u32, g32 = [list(map(promote, x)) for x in [update, grad]]
-
-
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ exp_avg32 = _lerp(exp_avg, u32, beta1)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
  u32 = torch._foreach_div(exp_avg32, denom)
  _compilable_update_(y, u32, decay, lr, caution, g32)

@@ -836,7 +923,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
  caution: bool):
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
-
+ _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


  @decorator_knowngood
@@ -846,17 +933,16 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
  beta2 = beta_debias(beta2, step)

  gp32 = list(map(promote, grad))
-
- denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
  gp32 = torch._foreach_div(gp32, denom)
- gp32 =
-
+ gp32 = _lerp(exp_avg, gp32, beta1)
  copy_stochastic_list_(grad, gp32)


- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+ eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
- beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0]
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
  return grad

@@ -864,23 +950,23 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
  @decorator_knowngood
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
  grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
- caution: bool):
+ caution: bool, eps: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

  u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
-
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
  u32 = torch._foreach_div(u32, denom)
- u32 =
+ u32 = _lerp(exp_avg, u32, beta1)
  _compilable_update_(y, u32, decay, lr, caution, gp32)


  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
+ eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
+ beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)


  @decorator_knowngood
@@ -889,14 +975,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
  _compilable_update_(y, u32, decay, lr, caution, g32)

  beta1 = beta_debias(beta1, step)
- denom =
-
- exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
- copy_stochastic_list_(exp_avg, exp_avg32)
+ denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+ stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

  beta2 = beta_debias(beta2, step + 1)
-
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)


  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
@@ -906,27 +989,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e


  @decorator_knowngood
- def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
  g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
  update = [e.clone() for e in exp_avg]

  beta1 = beta_debias(beta1, step)
- denom =
-
- exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
- copy_stochastic_list_(exp_avg, exp_avg32)
+ denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+ stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

-
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)

  copy_stochastic_list_(grad, update)


- def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
- _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+ _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
  return grad

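adopt, laprop_ and the fused variants now thread an explicit eps through scalar_guard instead of hard-coding 1e-8 inside the compiled kernels. A hedged call sketch with toy buffers; real callers pass the per-parameter state lists that chainable.py maintains, and scale_by_adopt invokes this with group['step'] - 2:

import torch
from heavyball import utils

grad = [torch.randn(10)]
exp_avg = [torch.zeros(10)]
exp_avg_sq = [torch.ones(10)]
update = utils.adopt(grad, exp_avg_sq, exp_avg, 0.9, 0.999, 3, eps=1e-8)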
@@ -1005,7 +1084,6 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  reusable einsum expressions for updating Q and preconditioning gradient.
  """
  letters = string.ascii_lowercase + string.ascii_uppercase
-
  dtype = dtype if dtype is not None else t.dtype
  shape = t.shape

@@ -1049,11 +1127,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  piece1A.append(letters[i])
  piece2A = piece2A + letters[i]
  piece3A = piece3A + letters[i]
-
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
  subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
  exprGs.append(subscripts)
-
  piece1P.append(letters[i + 13])
  piece2P.append(letters[i + 13])
  piece3P = piece3P + letters[i + 13]
@@ -1061,16 +1137,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  else:
  # use triangular matrix as preconditioner for this dim
  Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
-
  piece1A.append(letters[i] + letters[i + 13])
  piece2A = piece2A + letters[i + 13]
  piece3A = piece3A + letters[i]
-
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
  subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
  exprGs.append(subscripts)
-
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
  piece1P.append(a + b)
  piece2P.append(a + c)
@@ -1091,7 +1164,7 @@ def psgd_balance_Q(Q_in):


  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
- eps = scalar_guard(math.sqrt(torch.finfo(
+ eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
  eps *= G.norm() / G.numel()
  G = G + torch.randn_like(G) * eps
  md = min_dtype(Q + [G])
@@ -1117,9 +1190,7 @@ def psgd_lb(A, max_abs):
  A /= max_abs
  a0 = torch.einsum('ij,ij->j', A, A)
  i = torch.argmax(a0)
-
  x = torch.index_select(A, 1, i).flatten().contiguous()
-
  x = torch.einsum('i,ij->j', x, A)
  x /= x.norm()
  x = torch.einsum('j,kj->k', x, A)
@@ -1132,15 +1203,12 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs
-
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)

  for q, exprG, o in zip(Q, exprGs, oq):
  term1 = promote(torch.einsum(exprG, A, A))
  term2 = promote(torch.einsum(exprG, conjB, conjB))
-
  term1, term2 = term1 - term2, term1 + term2
-
  term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
@@ -1256,8 +1324,8 @@ def identity(x):

  @decorator_knowngood
  def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
- ema32 =
-
+ ema32 = _lerp(ema, p, ema_decay)
+ _lerp(p, ema32, 1 - weight_decay)


  def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
@@ -1267,10 +1335,10 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):


  @decorator_knowngood
- def _compilable_l1_weight_decay_to_ema_(p, ema,
- ema32 =
+ def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+ ema32 = _lerp(ema, p, ema_decay)
  for p_, e_ in zip(p, ema32):
- p32 = promote(
+ p32 = promote(p_)
  p32 = p32 + (p32 - e_).sign() * weight_decay
  copy_stochastic_(p_, p32)

@@ -1447,7 +1515,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:

  if graft:
  out = _compilable_grafting(g, out)
-
  copy_stochastic_(g, out)

{heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/METADATA
CHANGED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.
+ Version: 1.6.1
  Summary: Efficient optimizers
- Home-page: https://github.com/
- Author:
+ Home-page: https://github.com/HomebrewML/HeavyBall
+ Author: HeavyBall Authors
  Author-email: github.heavyball@nestler.sh
  License: BSD
  Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
+ correct_bias: bool = True, warmup_steps: int = 1,
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
  configurations.
-
heavyball-1.6.1.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+ heavyball/utils.py,sha256=iQxSQjw_sgJp4AvX71VdTJxJ_20Tdu7W2tdrYu5q2EI,55808
+ heavyball-1.6.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.6.1.dist-info/METADATA,sha256=yFMCDJPpD5jVOFtL4l_pM3jTw3_ZizeTSQ_ugVHIWKM,43441
+ heavyball-1.6.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.6.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.6.1.dist-info/RECORD,,
heavyball-1.5.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
- heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
- heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
- heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
- heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.5.3.dist-info/RECORD,,
{heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/LICENSE
File without changes
{heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/WHEEL
File without changes
{heavyball-1.5.3.dist-info → heavyball-1.6.1.dist-info}/top_level.txt
File without changes