heavyball 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +63 -20
- heavyball/chainable.py +15 -12
- heavyball/utils.py +245 -173
- {heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/METADATA +4 -6
- heavyball-1.6.0.dist-info/RECORD +8 -0
- heavyball-1.5.3.dist-info/RECORD +0 -8
- {heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/LICENSE +0 -0
- {heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/WHEEL +0 -0
- {heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
             Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
             https://arxiv.org/abs/2409.11321
             https://github.com/nikhilvyas/SOAP
-
-        ScheduleFree:
-            The Road Less Scheduled
-            Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-            https://arxiv.org/abs/2405.15682
-            https://github.com/facebookresearch/schedule_free
     """
     use_precond_schedule: bool = False

     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
- (5 removed lines not captured in this diff view)
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

         defaults = locals()
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
                          C.scale_by_soap)


+class ForeachSignLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+class ForeachSOLP(C.BaseOpt):
+    """
+    ForeachSOLP
+
+    Sources:
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+    """
+    use_precond_schedule: bool = False
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
+        if use_precond_schedule:
+            del defaults['precondition_frequency']
+            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+        else:
+            del defaults['precond_scheduler']
+            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
+                         functools.partial(C.scale_by_soap, inner='laprop'))
+
+
 class PaLMForeachSOAP(ForeachSOAP):
     use_precond_schedule: bool = False
     palm: bool = True
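For context (not part of the diff): the new `ForeachSignLaProp` chains LaProp's second-moment scaling with a sign step, and `ForeachSOLP` runs SOAP with a LaProp inner optimizer. A minimal usage sketch, assuming heavyball's optimizers follow the standard `torch.optim` interface; the model and data here are hypothetical:

```python
import torch
from torch import nn

import heavyball

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))  # hypothetical toy model
opt = heavyball.ForeachSignLaProp(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99))

for _ in range(10):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()
```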
@@ -163,7 +205,6 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)


-
 class LaPropOrtho(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -172,8 +213,8 @@ class LaPropOrtho(C.BaseOpt):
         defaults = locals()
         defaults.pop("self")
         params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+                         C.orthogonalize_grad_to_param)


 class ForeachPSGDKron(C.BaseOpt):
@@ -191,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
                  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool =
+                 stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
@@ -251,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon
+SignLaProp = ForeachSignLaProp

 __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
            "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
-
-           "
-           "
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp'  #
+           "ForeachAdamW", "ForeachSFAdamW",
+           "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+           "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+           'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
heavyball/chainable.py
CHANGED
@@ -127,7 +127,7 @@ def zero_guard(*names):


 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names
+    return functools.partial(CopyGuard, index=index, names=names)


 def general_guard(*names, init_fn, skip_first: bool = True):
@@ -222,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate

     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate

     utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -241,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate

     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate

     return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)


-def _init_soap(state, group, update, grad, param):
-    utils.init_preconditioner(grad, state,
+def _init_soap(state, group, update, grad, param, inner: str = ''):
+    utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])


 def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -343,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
 def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
-    update = utils.promote(update)
+    update = utils.promote(update)  # Promote to highest precision if needed

     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'],
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+                 group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

-    for u, q, gg,
-        utils.update_preconditioner(u, q, gg,
+    for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+        utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
                                     utils.beta_debias(group['shampoo_beta'], group['step']),
                                     group['is_preconditioning'])
     return precond
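For context (not part of the diff): `scale_by_soap` projects the update into the Shampoo eigenbases, runs the inner Adam or LaProp step there, projects the result back, and then refreshes the preconditioner statistics. A stripped-down sketch of that flow for a single 2-D parameter, using plain orthogonal bases instead of heavyball's stochastic-rounding helpers; the names here are illustrative, not heavyball's API:

```python
import torch

def soap_like_step(g, ql, qr, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.99, eps=1e-8):
    """Project -> inner Adam-style step -> project back (simplified, no bias correction)."""
    g_proj = ql.T @ g @ qr                                  # rotate the gradient into the eigenbasis
    exp_avg.mul_(beta1).add_(g_proj, alpha=1 - beta1)       # first moment, kept in the eigenbasis
    exp_avg_sq.mul_(beta2).addcmul_(g_proj, g_proj, value=1 - beta2)
    update = exp_avg / exp_avg_sq.sqrt().clamp(min=eps)
    return ql @ update @ qr.T                               # rotate the preconditioned step back

g = torch.randn(4, 3)
ql, _ = torch.linalg.qr(torch.randn(4, 4))
qr, _ = torch.linalg.qr(torch.randn(3, 3))
print(soap_like_step(g, ql, qr, torch.zeros(4, 3), torch.zeros(4, 3)).shape)  # torch.Size([4, 3])
```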
@@ -503,7 +506,7 @@ def create_branch(branches: List[List[callable]], merge_fn: callable):
     def _branch(state, group, update, grad, param):
         outputs = []
         for branch in branches:
-            branch_update = [torch.clone(
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
             branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
             if skip_update:
                 raise ValueError("Branches should not skip updates")
heavyball/utils.py
CHANGED
@@ -1,24 +1,40 @@
+import copy
 import functools
 import gc
+import inspect
 import math
 import random
 import string
+import sys
+import time
 import warnings
+from datetime import datetime
 from typing import List, Optional, Tuple, Callable, Union
+from unittest.mock import patch

+import hyperopt
 import numpy as np
 import torch
 from torch import Tensor
+from torch._dynamo import config
 from torch._dynamo.exc import TorchDynamoException
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map

+config.cache_size_limit = 2 ** 16
+
+np.warnings = warnings
+
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny

+base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
+             'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
+             'weight_decay': 1e-4}
+

 def decorator(func):
     compiled = None
@@ -51,7 +67,7 @@ def decorator_knowngood(func: Callable):
     return _fn


-einsum_base = string.ascii_lowercase
+einsum_base = string.ascii_lowercase


 @decorator_knowngood
@@ -103,29 +119,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
     but we want to merge conv kernels into fan-in or at least merge the kernel
     so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
     would've done
+
+    By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
     """
-    shape = grad.shape
     new_shape = []
- (8 removed lines not captured in this diff view)
-        curr_shape = sh
+    cum_size = 1
+
+    for s in grad.shape[1:][::-1]:
+        temp_size = cum_size * s
+        if temp_size > max_precond_dim:
+            if cum_size > 1:
+                new_shape.append(cum_size)
+                cum_size = s
             else:
-                new_shape.append(
-
+                new_shape.append(s)
+                cum_size = 1
         else:
-
-            new_shape = [*shape[:1], *new_shape[::-1]]
+            cum_size = temp_size

-    if
-        new_shape.append(
+    if cum_size > 1:
+        new_shape.append(cum_size)

-
+    new_shape = [grad.shape[0], *new_shape[::-1]]
+    new_grad = grad.reshape(new_shape)
     if not split:
         return new_grad

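To make the rewritten `dim_merger` loop concrete (example added for this summary, not from the diff): trailing dimensions are merged greedily as long as their product stays within `max_precond_dim`, so [128, 64, 3, 3] becomes [128, 576] when the limit allows it and [128, 64, 9] when it does not. A standalone sketch of the same greedy merge over a plain shape list:

```python
def merge_dims(shape, max_precond_dim):
    """Greedily merge trailing dims (after the first) while the product fits the limit."""
    new_shape, cum_size = [], 1
    for s in shape[1:][::-1]:
        temp_size = cum_size * s
        if temp_size > max_precond_dim:
            if cum_size > 1:
                new_shape.append(cum_size)
                cum_size = s
            else:
                new_shape.append(s)
                cum_size = 1
        else:
            cum_size = temp_size
    if cum_size > 1:
        new_shape.append(cum_size)
    return [shape[0], *new_shape[::-1]]

assert merge_dims([128, 64, 3, 3], 2048) == [128, 576]
assert merge_dims([128, 64, 3, 3], 64) == [128, 64, 9]
```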
@@ -153,12 +169,11 @@ def beta_debias(beta, step):
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    s32 =
-    s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
-    denom = torch._foreach_sqrt(s32)
-    denom = [d.clamp(min=eps) for d in denom]
+    s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
     copy_stochastic_list_(state, s32)

+    denom = [d.sqrt().clamp(min=eps) for d in s32]
+
     if out[0] is None:
         return denom

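In other words (note added for this summary): the second moment is still the exponential moving average s ← β₂·s + (1 − β₂)·g², but the denominator is now taken directly as sqrt(s) clamped at eps rather than going through `torch._foreach_sqrt`. A plain-tensor equivalent:

```python
import torch

def exp_avg_sq_step(s, g, beta2=0.99, eps=1e-8):
    s.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # s <- beta2 * s + (1 - beta2) * g * g
    return s.sqrt().clamp(min=eps)                 # denominator used to scale the update

print(exp_avg_sq_step(torch.zeros(4), torch.ones(4)))  # tensor([0.1000, 0.1000, 0.1000, 0.1000])
```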
@@ -174,7 +189,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-    g32 = promote
+    g32 = list(map(promote, grad))
     denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
     out = torch._foreach_div(g32, denom)
     copy_stochastic_list_(grad, out)
@@ -189,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):

 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-    lerped =
+    lerped = _lerp(state, grad, beta)
     copy_stochastic_list_(grad, lerped)


@@ -240,29 +255,31 @@ def clean():
     gc.collect()


-def
+def _ignore_warning(msg):
+    warnings.filterwarnings('ignore', f'.*{msg}.*')
+
+
+def set_torch(benchmark_limit: int = 32):
     cudnn.benchmark = True
     cudnn.deterministic = False
+    cudnn.benchmark_limit = benchmark_limit
     torch.use_deterministic_algorithms(False)
     torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
     opt_einsum.enabled = True
-    opt_einsum.strategy = "
+    opt_einsum.strategy = "dp"
+
+    # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
+    _ignore_warning(
+        'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
+    _ignore_warning(
+        'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')


 @decorator
 def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
-    """
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16
+    X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype)  # Preserve float64 if present
     X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
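Since the explanatory docstring was dropped above (note added for this summary): `zeropower_via_newtonschulz5` orthogonalizes G with a quintic Newton-Schulz iteration whose coefficients (3.4445, -4.7750, 2.0315) maximize the slope at zero, so it returns an approximate UVᵀ of the SVD with the singular values pushed toward 1 rather than an exact one. A minimal dense-tensor sketch of that iteration, illustrative only:

```python
import torch

def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (g.norm() + eps)                 # keep the top singular value <= 1
    transposed = g.size(0) > g.size(1)
    if transposed:
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x  # quintic update acting on the singular values
    return x.T if transposed else x

q = newton_schulz_orthogonalize(torch.randn(8, 4))
print(torch.linalg.svdvals(q))  # singular values are pushed toward 1
```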
@@ -319,7 +336,7 @@ def nesterov_momentum(state, grad, beta):

 @decorator_knowngood
 def _compilable_nesterov_ema_(state, grad, beta):
-    ema32 =
+    ema32 = _lerp(state, grad, beta)
     stochastic_add_(grad, ema32, 1)


@@ -330,12 +347,11 @@ def nesterov_ema(state, grad, beta):
     return grad


+@decorator_knowngood
 def _compilable_grafting(magnitude, direction):
     return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))


-# mode in ("newtonschulz", "qr", "svd")
-# scale_mode in ("none", "scale", "graft")
 @decorator_knowngood
 def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
     if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
@@ -363,74 +379,82 @@ def _compilable_scatter_set(target, source, index):
     target[:] = source.contiguous()[index].reshape_as(target)


-
+@decorator_knowngood
+def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
+    followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
+
+    :param GG: List of accumulated gradient outer products.
+    :param Q: List of current eigenbases (updated in-place to Q_new).
+    :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
     """
- (5 removed lines not captured in this diff view)
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            matrix.append(promote(m.data))
-            orth_matrix.append(promote(o.data))
-        else:
-            matrix.append(promote(m.data))
-            orth_matrix.append(promote(o.data))
+    if isinstance(Q, list) and not Q:
+        return
+
+    if exp_avg is not None and exp_avg.dim() != len(Q):
+        raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")

-
+    new_qs = []

-    for
+    for m, q in zip(GG, Q):
         if len(m) == 0:
-            indices.append(None)
             continue

-
-
+        m = promote(m.data)
+        q_old = promote(q.data)
+
+        tmp = m @ q_old
+        est_eig = torch.einsum('ij,ij->j', q_old, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
-        indices.append(sort_idx)
-        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

- (3 removed lines not captured in this diff view)
+        tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
+        new_qs.append(tmp)
+
+    if exp_avg is None:
+        for q, q_new in zip(Q, new_qs):
+            copy_stochastic_(q, q_new)
+        return
+
+    assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
+    in_str = einsum_base[:exp_avg.dim()]
+    out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
+
+    from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
+    if not from_shampoo:
+        return
+
+    to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
+    out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
+
+    subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
+    exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
+    copy_stochastic_(exp_avg, exp_avg_new)
+
+    for q, q_new in zip(Q, new_qs):
+        copy_stochastic_(q, q_new)


 def get_orthogonal_matrix(mat):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
-    matrix = []
-    for m in mat:
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            float_data = False
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(promote(m.data))
-        else:
-            float_data = True
-            matrix.append(m.data)

     final = []
-    for m in
+    for m in mat:
         if len(m) == 0:
             final.append([])
             continue

+        m = promote(m.data)
+
         device, dtype = m.device, m.dtype
         for modifier in (None, torch.double, 'cpu'):
             if modifier is not None:
                 m = m.to(modifier)
             try:
-
-
+                eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
+                eigvec = eigvec.to(device=device, dtype=dtype)
                 break
             except torch.OutOfMemoryError:
                 pass
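The main behavioral addition above is that `get_orthogonal_matrix_QR` now re-expresses the SOAP exponential moving average in the refreshed eigenbases instead of leaving it in the stale ones (note added for this summary). For a 2-D parameter and the usual QᵀMQ projection convention, the einsum it builds amounts to a two-sided basis change; a small illustrative sketch:

```python
import torch

def rotate_moment(m_old_basis, q_old, q_new):
    """Move a (left, right)-projected moment from old eigenbases to new ones (2-D case)."""
    ql_old, qr_old = q_old
    ql_new, qr_new = q_new
    m_orig = ql_old @ m_old_basis @ qr_old.T  # back to the original coordinates
    return ql_new.T @ m_orig @ qr_new         # into the new eigenbases

ql_old, _ = torch.linalg.qr(torch.randn(4, 4))
qr_old, _ = torch.linalg.qr(torch.randn(3, 3))
ql_new, _ = torch.linalg.qr(torch.randn(4, 4))
qr_new, _ = torch.linalg.qr(torch.randn(3, 3))
print(rotate_moment(torch.randn(4, 3), (ql_old, qr_old), (ql_new, qr_new)).shape)  # torch.Size([4, 3])
```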
@@ -440,9 +464,9 @@ def get_orthogonal_matrix(mat):
             else:
                 raise RuntimeError("Failed to compute eigenvalues.")

-
+        eigvec = torch.flip(eigvec, [1])

-        final.append(
+        final.append(eigvec)

     return final

@@ -467,7 +491,7 @@ def get_beta1(group):


 def get_beta2(group):
-    if 'beta2_scale' in group:
+    if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
     if 'betas' in group:
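The `get_beta2` change above makes the PaLM-style schedule opt-in (note added for this summary): β₂ = 1 − step^(−beta2_scale) is used only when the parameter group explicitly sets `palm=True`; otherwise the fixed `betas` entry applies. The schedule itself:

```python
def palm_beta2(step, beta2_scale=0.8):
    return 1 - max(step, 1) ** -beta2_scale

print([round(palm_beta2(s), 4) for s in (1, 10, 100, 1000)])
# [0.0, 0.8415, 0.9749, 0.996]
```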
@@ -536,20 +560,32 @@ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):


 @decorator
-def
+def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
+    """
+    Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
+    Re-commited due to faulty merge
+    """
     if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
         return

-    for idx,
-        if
+    for idx, m in enumerate(GG):
+        if not isinstance(m, Tensor):
             continue
         b = einsum_base[idx]
         g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-
+        m.lerp_(outer_product, 1 - beta)
+
+
+def tree_apply(fn):
+    def _fn(*args):
+        return tree_map(fn, *args)
+
+    return _fn


+@tree_apply
 def promote(x):
     if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
         return torch.float32
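What `update_ggt` accumulates (note added for this summary): for each tensor dimension it keeps an exponential moving average of the gradient's outer product along that dimension, GG_i ← β·GG_i + (1 − β)·(g contracted with itself over all other dims), which is what the einsum with one upper-cased index computes. A 2-D sketch:

```python
import torch

def update_ggt_2d(grad, GG, beta=0.95):
    GG[0].lerp_(torch.einsum('ab,Ab->aA', grad, grad), 1 - beta)  # row space:    grad @ grad.T
    GG[1].lerp_(torch.einsum('ab,aB->bB', grad, grad), 1 - beta)  # column space: grad.T @ grad

g = torch.randn(4, 3)
GG = [torch.zeros(4, 4), torch.zeros(3, 3)]
update_ggt_2d(g, GG)
print(torch.allclose(GG[0], 0.05 * g @ g.T))  # True
```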
@@ -566,45 +602,38 @@ def min_dtype(xs: List[Tensor]):
     return torch.float32


-def update_preconditioner(grad, Q, GG,
+def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
     """
     Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
     """
-
+    update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
     if update_precond:
-        get_orthogonal_matrix_QR(GG, Q,
+        get_orthogonal_matrix_QR(GG, Q, exp_avg)


-def init_preconditioner(grad, state,
+def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
     """
     Initializes the preconditioner matrices (L and R in the paper).
     """
     state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
-    if grad.
-        if precondition_1d or grad.shape[0] > max_precond_dim:
-            state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
-        else:
-            state['GG'].append([])
-
-    else:
+    if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
         for sh in grad.shape:
             if sh > max_precond_dim:
-                state['GG'].append(
+                state['GG'].append(None)
             else:
                 state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+    else:
+        state['GG'].append(None)

-
+    update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
     state['Q'] = get_orthogonal_matrix(state['GG'])


 @decorator
 def project(grad, Q, back: bool):
     """
-
     :param grad:
     :param Q:
-    :param merge_dims:
-    :param max_precond_dim:
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
@@ -617,12 +646,40 @@ def project(grad, Q, back: bool):
     return grad


+def modify_closure(closure):
+    """
+    Modifies the closure function to use create_graph=True in backward().
+
+    Args:
+        closure: The closure function passed to the optimizer.
+
+    Returns:
+        The return value of the modified closure.
+    """
+
+    def patched_backward(self, *args, **kwargs):
+        kwargs['create_graph'] = True
+        return original_backward(self, *args, **kwargs)
+
+    original_backward = torch.Tensor.backward
+
+    with patch.object(torch.Tensor, 'backward', patched_backward):
+        return closure()
+
+
 class StatefulOptimizer(torch.optim.Optimizer):
+    """
+    finite_differences saves memory, but needs more compute. (Alternative is true HVP)
+    Both `True` and `False` have some edge cases they don't support, so experiment with it.
+    The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
+    Further notice that both methods have different numerics outputs
+    """
     ema_decay: float = 0.001
     compile_step: bool = False
     hessian_approx: bool = False
     precond_schedule: Union[Callable, float, None] = None
     stochastic_schedule: bool = False
+    finite_differences: bool = False

     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, 'foreach': foreach})
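The new `modify_closure` helper above temporarily patches `torch.Tensor.backward` so that whatever backward call the user's closure makes runs with `create_graph=True`, which the Hessian-vector path below needs (note added for this summary). For a closure that calls `loss.backward()`, it behaves roughly as if the closure had been written like this:

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    model.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward(create_graph=True)  # what modify_closure injects into a plain loss.backward()
    return loss

loss = closure()
print(model.weight.grad.requires_grad)  # True: the gradient stays in the autograd graph
```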
@@ -735,38 +792,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 set_(self.state_(p)['param_ema'], p.data)
                 set_(p.data, ema_clone)

-    def
-        if self.precond_schedule is None:
-            self._is_preconditioning = False
-        else:
-            self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+    def _handle_closure(self, closure):
         hessian_approx = self.hessian_approx and self._is_preconditioning
+
         if closure is None:
             if hessian_approx:
                 raise ValueError("Hessian approximation requires a closure.")
-
-
+            return None
+
+        if not hessian_approx:
             with torch.enable_grad():
                 loss = closure()
- (19 removed lines not captured in this diff view)
+            return loss
+
+        if self.finite_differences:
+            with torch.enable_grad():
+                loss = closure()  # closure without retain_graph=True
+
+            grads = []
+            for group in self.param_groups:
+                for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+                    grads.append(g)
+                    p.vector = torch.randn_like(p)
+                    p.orig = p.data.clone()
+                    stochastic_add_(p.data, p.vector, tiny_bf16)
+        else:
+            with torch.enable_grad():
+                loss = modify_closure(closure)
+
+        if self.finite_differences:
+            with torch.enable_grad():
+                closure()
+
+            for group in self.param_groups:
+                for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+                    p.grad = grads.pop(0)
+                    stochastic_add_(g, p.grad, -1)
+                    p.hessian_vector = g
+                    p.data.copy_(p.orig)
+                    del p.orig
+        else:
+            for group in self.param_groups:
+                for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+                    p.grad = g
+            params, grads = zip(*[x for group in self.param_groups for x in
+                                  self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
+            vs = [torch.randn_like(p) for p in params]
+            with torch.enable_grad():
+                hvs = torch.autograd.grad(grads, params, vs)
+
+            for p, g, v, hv in zip(params, grads, vs, hvs):
+                p.hessian_vector = hv
+                p.grad = g
+                p.vector = v
+
+        return loss
+
+    def step(self, closure: Optional[Callable] = None):
+        if self.precond_schedule is None:
+            self._is_preconditioning = False
+        else:
+            self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+        loss = self._handle_closure(closure)

         # we assume that parameters are constant and that there are no excessive recompiles
         with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
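The two branches of `_handle_closure` above compute the same Hessian-vector product in different ways (note added for this summary): the `finite_differences` path re-evaluates the gradient at parameters nudged along a random vector and differences the two gradients, while the other path takes the exact product through a double backward with `torch.autograd.grad`. A compact single-parameter sketch of both, for a scalar loss `f`; illustrative only, not heavyball's API:

```python
import torch

def hvp_autograd(f, p, v):
    """Exact Hessian-vector product via double backward."""
    (g,) = torch.autograd.grad(f(p), p, create_graph=True)
    (hv,) = torch.autograd.grad(g, p, grad_outputs=v)
    return hv

def hvp_finite_differences(f, p, v, eps=1e-4):
    """Approximate Hessian-vector product: (grad(p + eps * v) - grad(p)) / eps."""
    (g0,) = torch.autograd.grad(f(p), p)
    (g1,) = torch.autograd.grad(f(p + eps * v), p)
    return (g1 - g0) / eps

p = torch.randn(5, dtype=torch.float64, requires_grad=True)
v = torch.randn(5, dtype=torch.float64)
f = lambda t: (t ** 4).sum()
print(torch.allclose(hvp_autograd(f, p, v), hvp_finite_differences(f, p, v), atol=1e-1))  # True
```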
@@ -774,7 +861,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 group['is_preconditioning'] = self._is_preconditioning
                 self._step(group)
             if self.use_ema:
-                self.ema_update(
+                self.ema_update()

         return loss

@@ -784,12 +871,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
         copy_stochastic_(t, s)


-
+@decorator_knowngood
+def _lerp(state: List[Tensor], grad: List[Tensor], beta):
     ea32 = list(map(promote, state))
     grad = list(map(promote, grad))
     beta = promote(beta)
-
-    ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+    ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
     copy_stochastic_list_(state, ea32)
     return ea32

@@ -801,10 +888,9 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
     beta2 = beta_debias(beta2, step)

     g32 = list(map(promote, grad))
- (3 removed lines not captured in this diff view)
-    u32 = torch._foreach_div(exp_avg32, denom)
+    exp_avg32 = _lerp(exp_avg, g32, beta1)
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+    u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
     copy_stochastic_list_(grad, u32)


@@ -824,9 +910,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
     beta2 = beta_debias(beta2, step)

     u32, g32 = [list(map(promote, x)) for x in [update, grad]]
-
-
-    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+    exp_avg32 = _lerp(exp_avg, u32, beta1)
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
     u32 = torch._foreach_div(exp_avg32, denom)
     _compilable_update_(y, u32, decay, lr, caution, g32)

@@ -836,7 +921,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
                 caution: bool):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
-
+    _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


 @decorator_knowngood
@@ -846,17 +931,16 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
     beta2 = beta_debias(beta2, step)

     gp32 = list(map(promote, grad))
-
-    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
     gp32 = torch._foreach_div(gp32, denom)
-    gp32 =
-
+    gp32 = _lerp(exp_avg, gp32, beta1)
     copy_stochastic_list_(grad, gp32)


-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+            eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0]
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad

@@ -864,23 +948,23 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
 @decorator_knowngood
 def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
                               grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
-                              caution: bool):
+                              caution: bool, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

     u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
-
-    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
     u32 = torch._foreach_div(u32, denom)
-    u32 =
+    u32 = _lerp(exp_avg, u32, beta1)
     _compilable_update_(y, u32, decay, lr, caution, gp32)


 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
-                  grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool
+                  grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
+                  eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
-    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
+    beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)


 @decorator_knowngood
@@ -917,7 +1001,7 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
     copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32,
+    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

     copy_stochastic_list_(grad, update)
@@ -1005,7 +1089,6 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
     reusable einsum expressions for updating Q and preconditioning gradient.
     """
     letters = string.ascii_lowercase + string.ascii_uppercase
-
     dtype = dtype if dtype is not None else t.dtype
     shape = t.shape

@@ -1049,11 +1132,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
             piece1A.append(letters[i])
             piece2A = piece2A + letters[i]
             piece3A = piece3A + letters[i]
-
             piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
             subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
             exprGs.append(subscripts)
-
             piece1P.append(letters[i + 13])
             piece2P.append(letters[i + 13])
             piece3P = piece3P + letters[i + 13]
@@ -1061,16 +1142,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
         else:
             # use triangular matrix as preconditioner for this dim
             Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
-
             piece1A.append(letters[i] + letters[i + 13])
             piece2A = piece2A + letters[i + 13]
             piece3A = piece3A + letters[i]
-
             piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
             piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
             subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
             exprGs.append(subscripts)
-
             a, b, c = (letters[i], letters[i + 13], letters[i + 26])
             piece1P.append(a + b)
             piece2P.append(a + c)
@@ -1091,7 +1169,7 @@ def psgd_balance_Q(Q_in):


 def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
-    eps = scalar_guard(math.sqrt(torch.finfo(
+    eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
     eps *= G.norm() / G.numel()
     G = G + torch.randn_like(G) * eps
     md = min_dtype(Q + [G])
@@ -1117,9 +1195,7 @@ def psgd_lb(A, max_abs):
     A /= max_abs
     a0 = torch.einsum('ij,ij->j', A, A)
     i = torch.argmax(a0)
-
     x = torch.index_select(A, 1, i).flatten().contiguous()
-
     x = torch.einsum('i,ij->j', x, A)
     x /= x.norm()
     x = torch.einsum('j,kj->k', x, A)
@@ -1132,15 +1208,12 @@ def psgd_lb(A, max_abs):
 def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
-
     A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)

     for q, exprG, o in zip(Q, exprGs, oq):
         term1 = promote(torch.einsum(exprG, A, A))
         term2 = promote(torch.einsum(exprG, conjB, conjB))
-
         term1, term2 = term1 - term2, term1 + term2
-
         term1 *= precond_lr
         norm = term2.norm(float('inf'))
         if q.dim() < 2:
@@ -1256,8 +1329,8 @@ def identity(x):

 @decorator_knowngood
 def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
-    ema32 =
-
+    ema32 = _lerp(ema, p, ema_decay)
+    _lerp(p, ema32, 1 - weight_decay)


 def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
@@ -1267,10 +1340,10 @@ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):


 @decorator_knowngood
-def _compilable_l1_weight_decay_to_ema_(p, ema,
-    ema32 =
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp(ema, p, ema_decay)
     for p_, e_ in zip(p, ema32):
-        p32 = promote(
+        p32 = promote(p_)
         p32 = p32 + (p32 - e_).sign() * weight_decay
         copy_stochastic_(p_, p32)

@@ -1447,7 +1520,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:

     if graft:
         out = _compilable_grafting(g, out)
-
     copy_stochastic_(g, out)

{heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/METADATA
CHANGED
@@ -1,9 +1,9 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.
+Version: 1.6.0
 Summary: Efficient optimizers
-Home-page: https://github.com/
-Author:
+Home-page: https://github.com/HomebrewML/HeavyBall
+Author: HeavyBall Authors
 Author-email: github.heavyball@nestler.sh
 License: BSD
 Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
+                 correct_bias: bool = True, warmup_steps: int = 1,
                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
 * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
 * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
 * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
-* **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
 * **`correct_bias`**: Enables/disables bias correction for the running averages.
 * **`warmup_steps`**: Number of steps for linear learning rate warmup.
 * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
 * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
   `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
   configurations.
-
heavyball-1.6.0.dist-info/RECORD
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
+heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
+heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.6.0.dist-info/RECORD,,
heavyball-1.5.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
-heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
-heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
-heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
-heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.3.dist-info/RECORD,,
{heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/LICENSE: file without changes
{heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/WHEEL: file without changes
{heavyball-1.5.3.dist-info → heavyball-1.6.0.dist-info}/top_level.txt: file without changes