heavyball 1.4.3__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +36 -5
- heavyball/chainable.py +38 -23
- heavyball/utils.py +159 -101
- {heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/METADATA +1 -1
- heavyball-1.5.0.dist-info/RECORD +8 -0
- heavyball-1.4.3.dist-info/RECORD +0 -8
- {heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/LICENSE +0 -0
- {heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/WHEEL +0 -0
- {heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED

@@ -116,7 +116,7 @@ class ForeachSOAP(C.BaseOpt):
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int =
+                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -129,8 +129,10 @@ class ForeachSOAP(C.BaseOpt):

         if use_precond_schedule:
             del defaults['precondition_frequency']
+            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
         else:
             del defaults['precond_scheduler']
+            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
                          C.scale_by_soap)

@@ -149,6 +151,30 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
     palm: bool = True


+class OrthoLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.orthogonalize_grad_to_param, C.scale_by_laprop)
+
+
+class OrthoAdamW(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.orthogonalize_grad_to_param, C.scale_by_adam)
+
+
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -162,7 +188,7 @@ class ForeachPSGDKron(C.BaseOpt):

     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
-                 momentum_into_precond_update=True, warmup_steps: int =
+                 momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
                  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
@@ -172,6 +198,8 @@ class ForeachPSGDKron(C.BaseOpt):
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
         defaults.pop("self")
+        self.precond_schedule = defaults.pop(
+            "preconditioner_update_probability") or utils.precond_update_prob_schedule()
         params = defaults.pop("params")

         delayed = C.default(delayed, self.delayed)
@@ -181,8 +209,7 @@ class ForeachPSGDKron(C.BaseOpt):

         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,  #
                          *(C.exp_avg,) * exp_avg_input,  #
-                         functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached
-                                           prob=preconditioner_update_probability))
+                         functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached))


 class ForeachPurePSGD(ForeachPSGDKron):
@@ -202,6 +229,10 @@ class ForeachDelayedPSGD(ForeachPSGDKron):
     delayed: bool = True


+class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
+    hessian_approx = True
+
+
 PalmForEachSoap = PaLMForeachSOAP
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
@@ -225,4 +256,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
            "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop", "ForeachMuon"]
+           "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
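For orientation, the `Ortho*` classes added above are ordinary `torch.optim.Optimizer` subclasses built from chainable transforms. A minimal usage sketch, assuming heavyball 1.5.0 (the toy model, data, and learning rate are illustrative only, not from the package):

```python
import torch
import heavyball  # assumes heavyball >= 1.5.0, where OrthoLaProp was added

model = torch.nn.Linear(16, 1)
opt = heavyball.OrthoLaProp(model.parameters(), lr=2.5e-3)

for _ in range(10):
    x, y = torch.randn(32, 16), torch.randn(32, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()       # gradient is orthogonalized to the weights, then LaProp scaling is applied
    opt.zero_grad()
```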
heavyball/chainable.py
CHANGED

@@ -127,7 +127,7 @@ def zero_guard(*names):


 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names,)
+    return functools.partial(CopyGuard, index=index, names=names, )


 def general_guard(*names, init_fn, skip_first: bool = True):
@@ -188,6 +188,11 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate


+@no_state
+def orthogonalize_grad_to_param(group, update, grad, param):
+    return utils.orthogonalize_grad_to_param(param, update, group['eps'])
+
+
 @copy_guard(2, "z")
 @no_state
 def update_by_schedule_free(group, update, grad, param, z):
@@ -312,17 +317,23 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:

     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
         utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
-                                    utils.beta_debias(group['shampoo_beta'], group['step']),
+                                    utils.beta_debias(group['shampoo_beta'], group['step']),
+                                    group['is_preconditioning'])
     return precond


 def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
     if prob is None:
         prob = utils.precond_update_prob_schedule()
-
+
+    if not group['is_preconditioning']:
         return Q_mat

-    utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q,
+    utils.psgd_update_precond(Q_mat, exprs, getattr(param, 'hessian_vector', grad), group['precond_lr'], Q,
+                              group['store_triu_as_line'], getattr(param, 'vector', None))
+    if hasattr(param, 'vector'):
+        del param.vector
+        del param.hessian_vector

     if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
         if group['store_triu_as_line']:
@@ -330,7 +341,15 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
         else:
             utils.psgd_balance_Q(Q)

-
+    if isinstance(prob, float):
+        float_prob = prob
+    else:
+        float_prob = prob(group.get(f'cumulative_prob_{id(Q)}_prob_step', 1))
+    group['is_cached'] = should_use_cache = cached and float_prob < 0.5
+
+    if should_use_cache:  # caching adds extra ops and is not worth the overhead when we precondition at every step
+        return _update_psgd_cache(cached, Q_cache, Q_mat)
+    return Q_mat


 def _update_psgd_cache(cached, Q_cache, q):
@@ -345,14 +364,14 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache


-def _cached_psgd_precond_grad(
-    if
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+    if group.get('is_cached', False):
         return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
     return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)


-def _fused_cached_psgd_precond_grad(group, grad, param,
-    if
+def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
+    if group.get('is_cached', False):
         utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
                                          group['caution'], *Q_cache)
     else:
@@ -368,7 +387,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)


 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -376,9 +395,9 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(
-    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
-
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                             Q_mat, Q, exprs, prob)
     return precond


@@ -389,7 +408,7 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    _fused_cached_psgd_precond_grad(group, update, param,
+    _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate


@@ -398,9 +417,9 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _fused_cached_psgd_precond_grad(group, update, param,
-    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
-
+    _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
+    _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                             Q_mat, Q, exprs, prob)
    raise SkipUpdate


@@ -449,6 +468,7 @@ class ChainOpt(utils.StatefulOptimizer):
             group['base_lr'] = group['lr']

         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
+
         if not vals:
             return
         p, g = zip(*vals)
@@ -464,12 +484,7 @@ class ChainOpt(utils.StatefulOptimizer):
                 break

         group['step'] = state['step'] = step = step + 1
-
-        if group['warmup_steps'] and step < group['warmup_steps']:
-            group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
-
-        else:
-            group['prev_lr'] = group['lr'] = group['base_lr']
+        group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1)

         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
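The warmup rewrite in `ChainOpt` replaces the old if/else with the single expression `base_lr * step / max(step, warmup_steps + 1)`: a linear ramp that reaches the full base learning rate at step `warmup_steps + 1` and stays there afterwards. A standalone check of that behaviour (illustrative helper, not package code):

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Same expression as the new ChainOpt code path above.
    return base_lr * step / max(step, warmup_steps + 1)

# With warmup_steps=4, steps 1..5 ramp linearly to base_lr; later steps stay at base_lr.
print([round(warmup_lr(1.0, s, 4), 2) for s in range(1, 8)])  # [0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0]
```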
heavyball/utils.py
CHANGED

@@ -54,12 +54,6 @@ def decorator_knowngood(func: Callable):
 einsum_base = string.ascii_lowercase + string.ascii_uppercase


-def warmup(lr: float, step: int, warmup_steps: int):
-    if step >= warmup_steps:  # if instead of min to guard against 0 div
-        return lr
-    return lr * step / warmup_steps
-
-
 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
                                beta1: Tensor, decay: float, grad: List[Tensor], caution):
@@ -323,6 +317,10 @@ def nesterov_momentum(state, grad, beta):
     return grad


+def _compilable_grafting(magnitude, direction):
+    return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
+
+
 # mode in ("newtonschulz", "qr", "svd")
 # scale_mode in ("none", "scale", "graft")
 @decorator_knowngood
@@ -341,12 +339,17 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
     elif scale_mode == "scale":
         y *= max(1, x.size(0) / x.size(1)) ** 0.5
     elif scale_mode == "graft":
-        y
+        y = _compilable_grafting(x, y)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
     set_(out, y)


+@decorator_knowngood
+def _compilable_scatter_set(target, source, index):
+    target[:] = source.contiguous()[index].reshape_as(target)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
@@ -381,7 +384,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):

     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
-
+    _compilable_scatter_set(exp_avg_sq, exp_avg_sq, indices)


 def get_orthogonal_matrix(mat):
@@ -482,7 +485,7 @@ def scalar_guard(*args):
     out = []
     for x in xs:
         if isinstance(x, float):
-            out.append(torch.empty((), dtype=
+            out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
         elif isinstance(x, int):
             out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
         else:
@@ -500,7 +503,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
         copy_stochastic_(x_, x32 + y32 * alpha)


-def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
+def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
     x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
@@ -591,25 +594,26 @@ def project(grad, Q, back: bool):
 class StatefulOptimizer(torch.optim.Optimizer):
     ema_decay: float = 0.001
     compile_step: bool = False
+    hessian_approx: bool = False
+    precond_schedule: Union[Callable, float, None] = None
+    stochastic_schedule: bool = False

     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, 'foreach': foreach})
-        self.fake_groups = {}
         self.use_ema = use_ema
         self.mapping = {}
+        self._inner_group = {'stochastic_schedule': self.stochastic_schedule}
+        self._precond_rng = random.Random(0x12312)
+        self._is_preconditioning = None

-
-
-        return [group]
-
-        for p in group['params']:
-            if p not in self.fake_groups:
-                self.fake_groups[p] = {**group, 'params': [p]}
+        if self.hessian_approx and self.compile_step:
+            raise ValueError("Hessian approximation can't be used with compile_step.")

-
+    def get_groups(self, group):
+        return [group]

     def state_(self, arg: Tensor):
-        return self.state[
+        return self.state[arg]

     def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
         for p, g in zip(p_list, g_list):
@@ -622,36 +626,27 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
                                beta1: float = -1.0):
         for p in group["params"]:
-            if p
-
-                continue
-            grad = None
+            if p in self.mapping:
+                p_views = self.mapping[p]
             else:
-
-                    grad = promote(p.grad)
-                else:
-                    grad = p.grad
-                if beta1 >= 0 and group.get('mars', False):
-                    self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
-
-            p.grad = None
+                self.mapping[p] = p_views = merge_group(group, p)

-
-
-                continue
+            grad = getattr(p, 'grad', None)
+            p.grad = None

-            p_views = merge_group(group, p)
-            if grad is not None:
-                grad = merge_group(group, grad)
-            for i, pv in enumerate(p_views):
-                self.mapping[pv] = (p, i)
-            if isinstance(p_views, Tensor):
-                yield p_views, grad
-                continue
             if grad is None:
-
-
-
+                grad = [getattr(pv, 'grad', None) for pv in p_views]
+            else:
+                grad = merge_group(group, grad)
+
+            for pv, g in zip(p_views, grad):
+                if skip_none and g is None:
+                    continue
+                if should_promote:
+                    g = promote(g)
+                if beta1 >= 0 and group.get('mars', False):
+                    self.mars_correct_list(group, [pv], [g], group['mars_gamma'], beta1)
+                yield pv, g

     def state_size(self) -> int:
         total_bytes = 0
@@ -671,67 +666,89 @@ class StatefulOptimizer(torch.optim.Optimizer):

     def ema_update(self):
         with torch.no_grad():
-            for
-            for
-                active_p = [p for p in group['params']]
+            for group in self.param_groups:
+                active_p = [p for p in group['params']]

-
-
+                if not active_p:
+                    return

-
+                k = group['ema_step'] = group.get('ema_step', -1) + 1

-
-
-
+                for p in active_p:
+                    if 'param_ema' not in self.state_(p):
+                        self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

-
-
+                y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+                torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))

     def copy_emas_to_params(self):
         with torch.no_grad():
-            for
-            for
-                active_p = [p for p in group['params']]
+            for group in self.param_groups:
+                active_p = [p for p in group['params']]

-
-
+                if not active_p:
+                    return

-
-
-
-
-
+                for p in active_p:
+                    if 'param_ema' in self.state_(p):
+                        p_clone = p.data.clone()
+                        set_(p.data, self.state_(p)['param_ema'])
+                        set_(self.state_(p)['param_ema'], p_clone)

     def copy_params_to_emas(self):
         with torch.no_grad():
-            for
-            for
-                active_p = [p for p in group['params']]
+            for group in self.param_groups:
+                active_p = [p for p in group['params']]

-
-
+                if not active_p:
+                    return

-
-
-
-
-
+                for p in active_p:
+                    if 'param_ema' in self.state_(p):
+                        ema_clone = self.state_(p)['param_ema'].data.clone()
+                        set_(self.state_(p)['param_ema'], p.data)
+                        set_(p.data, ema_clone)

     def step(self, closure: Optional[Callable] = None):
+        if self.precond_schedule is None:
+            self._is_preconditioning = False
+        else:
+            self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+        hessian_approx = self.hessian_approx and self._is_preconditioning
         if closure is None:
+            if hessian_approx:
+                raise ValueError("Hessian approximation requires a closure.")
             loss = None
         else:
             with torch.enable_grad():
                 loss = closure()
+            if hessian_approx:
+                grads = []
+                for group in self.param_groups:
+                    for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+                        grads.append(g)
+                        p.vector = torch.randn_like(p)
+                        p.orig = p.data.clone()
+                        stochastic_add_(p.data, p.vector, tiny_bf16)
+
+                with torch.enable_grad():
+                    closure()
+
+                for group in self.param_groups:
+                    for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+                        p.grad = grads.pop(0)
+                        stochastic_add_(g, p.grad, -1)
+                        p.hessian_vector = g
+                        p.data.copy_(p.orig)
+                        del p.orig

         # we assume that parameters are constant and that there are no excessive recompiles
         with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
-            for
-
-
-
-
-                self.ema_update(group)
+            for group in self.param_groups:
+                group['is_preconditioning'] = self._is_preconditioning
+                self._step(group)
+                if self.use_ema:
+                    self.ema_update(group)

         return loss

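The new `hessian_approx` path in `step` evaluates the closure twice: once at the current parameters and once after nudging them by a small random `vector`, then stores the gradient difference on each parameter as `hessian_vector` for the PSGD preconditioner update. A standalone sketch of that finite-difference Hessian-vector idea (illustrative only; the helper name is made up and `eps` stands in for the tiny perturbation used above):

```python
import torch

def hvp_finite_difference(loss_fn, theta: torch.Tensor, v: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Approximate H @ v via (grad(theta + eps * v) - grad(theta)) / eps."""
    theta0 = theta.detach().clone().requires_grad_(True)
    g0 = torch.autograd.grad(loss_fn(theta0), theta0)[0]
    theta1 = (theta.detach() + eps * v).requires_grad_(True)
    g1 = torch.autograd.grad(loss_fn(theta1), theta1)[0]
    return (g1 - g0) / eps

A = torch.randn(5, 5)
A = A @ A.T  # symmetric, so the Hessian of 0.5 * x @ A @ x is exactly A
theta, v = torch.randn(5), torch.randn(5)
approx = hvp_finite_difference(lambda x: 0.5 * x @ A @ x, theta, v)
print(torch.allclose(approx, A @ v, rtol=1e-2, atol=1e-2))  # True up to float error
```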
@@ -943,6 +960,15 @@ def precond_schedule(step, precond_scheduler, rng):
     return update_precond


+def get_soap_precond_schedule(precond_scheduler):
+    rng = random.Random(0x12312)
+
+    def _inner(step):
+        return precond_schedule(step, precond_scheduler, rng)
+
+    return _inner
+
+
 def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
@@ -1030,20 +1056,24 @@ def psgd_balance_Q(Q_in):
     torch._foreach_mul_(Q_in, list(norms))


-def psgd_calc_A_and_conjB(exprA, G, Q):
+def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
     eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
     eps *= G.norm() / G.numel()
     G = G + torch.randn_like(G) * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
-
+    if V is None:
+        conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+    else:
+        conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
             conjB /= q
         else:
-            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
+            conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
+                conjB)
         if i < order - 1:
             conjB = torch.transpose(conjB, i, order - 1)
     return A, conjB
@@ -1065,11 +1095,11 @@ def psgd_lb(A, max_abs):


 @decorator
-def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
+def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs

-    A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)
+    A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)

     for q, exprG, o in zip(Q, exprGs, oq):
         term1 = promote(torch.einsum(exprG, A, A))
@@ -1286,7 +1316,6 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra


 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)

@@ -1325,6 +1354,29 @@ def mars_correction(g, old_g, beta1, gamma):
     _compilable_mars_correction_(g, old_g, a)


+@decorator_knowngood
+def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps: Tensor, graft: bool = True):
+    """
+    Implements OrthoGrad from "Grokking at the Edge of Numerical Stability" (https://arxiv.org/abs/2501.04697)
+    """
+
+    for w, g in zip(weight, grad):
+        proj = promote((w * g).sum()) / promote((w * w).sum()).add(eps)
+        out = promote(g) - proj * promote(w)  # promote in this funky way to keep traffic minimal
+
+        if graft:
+            out = _compilable_grafting(g, out)
+
+        copy_stochastic_(g, out)
+
+
+def orthogonalize_grad_to_param(weight, grad, eps, graft=True):
+    weight, grad = list_guard(weight, grad)
+    eps = scalar_guard(eps, weight[0])
+    _compilable_orthogonalization(weight, grad, eps, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_cautioning(g: Tensor, update: Tensor):
     mask = g.signbit() ^ update.signbit()  # "Mask if they point in different directions"
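`_compilable_orthogonalization` above projects the gradient onto the subspace orthogonal to the weight and then grafts the original gradient norm back on. The same math on plain tensors, as an illustrative sketch (the package version runs over lists of tensors in compiled form with stochastic rounding):

```python
import torch

def orthograd(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    proj = (w * g).sum() / ((w * w).sum() + eps)
    out = g - proj * w                                     # component of g orthogonal to w
    return out * (g.norm() / out.norm().clamp(min=1e-6))   # graft: keep the original gradient magnitude

w, g = torch.randn(10), torch.randn(10)
g_orth = orthograd(w, g)
print(torch.dot(w, g_orth).abs() < 1e-4)       # ~0: the update no longer grows/shrinks ||w|| to first order
print(torch.isclose(g_orth.norm(), g.norm()))  # magnitude matches the raw gradient
```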
@@ -1338,25 +1390,20 @@ def caution(g, update):
     return _compilable_cautioning(g, update)


-def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.
+def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
     """Anneal preconditioner update probability during beginning of training.

     PSGD benefits from more preconditioner updates at the beginning of training,
     but once the preconditioner is learned the update probability can drop low.

     This schedule is an exponential anneal with a flat start. Default settings keep
-    update probability at
-    `min_prob` by 4000 steps. Default settings work very well for most models and
+    update probability at `max_prob` for 1000 steps then exponentially anneal down to
+    `min_prob` by ~4000 steps. Default settings work very well for most models and
     training regimes.
     """

     def _schedule(n):
-
-            return max_prob
-
-        n -= flat_start
-        prob = max_prob * math.exp(-decay * (n - flat_start))
-        return max(min_prob, min(max_prob, prob))
+        return max(min_prob, max_prob * decay ** max(n - flat_start, 0))

     return _schedule

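The schedule body is now the closed form `max(min_prob, max_prob * decay ** max(n - flat_start, 0))`: constant at `max_prob` for the first `flat_start` steps, then geometric decay until it is clipped at `min_prob` (with the defaults, `0.999 ** 3500 ≈ 0.03`, so the floor is reached roughly 3500 steps after the flat start). An illustrative evaluation of the default schedule:

```python
def precond_prob(n, max_prob=1.0, min_prob=0.03, decay=0.999, flat_start=1000):
    # Same closed form as the new _schedule above.
    return max(min_prob, max_prob * decay ** max(n - flat_start, 0))

for n in (0, 1000, 2000, 3000, 4000, 5000):
    print(n, round(precond_prob(n), 3))
# 0 1.0 / 1000 1.0 / 2000 0.368 / 3000 0.135 / 4000 0.05 / 5000 0.03
```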
@@ -1375,12 +1422,18 @@ def merge_group(group, *tensors):


 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
-
+    optimizers = {}
+
+    def _step(p: Tensor):
+        o = optimizers[p]
         o.step()
         o.zero_grad()

     for p in model.parameters():
-        p
+        optimizers[p] = optimizer([p], *args, **kwargs)
+        p.register_post_accumulate_grad_hook(_step)
+
+    return optimizers


 def fused_hook(parameters, optimizer, *args, **kwargs):
@@ -1389,18 +1442,23 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
     seen_params = set()

     o = optimizer(parameters, *args, **kwargs)
+    step_fn = o.step
+    o.step = functools.partial(warn_once,
+                               msg="You're trying to call `step` on a fused optimizer. This will not do anything.")

     def _step(p: Tensor):
         seen_params.add(p)

         if len(seen_params) < param_count:
-
+            step_fn()
         o.zero_grad()
         seen_params.clear()

     for p in parameters:
         p.register_post_accumulate_grad_hook(_step)

+    return o
+

 @decorator_knowngood
 def _compilable_caution_no_scale(g: Tensor, update: Tensor):
heavyball-1.5.0.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=AL3oSbNB1HQ0cwEG6aPZGVMbpCXXCYOxREX7JwK4Byc,12773
+heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
+heavyball/utils.py,sha256=NFvQcQemNOugH1vAi_UH3jnnttPSgVopmS1q6jbhxkQ,50289
+heavyball-1.5.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.0.dist-info/METADATA,sha256=fUOCJvDcBQ5280TCLhUCuIRwNVMvp3ysp4qrDuJCUeI,43584
+heavyball-1.5.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.0.dist-info/RECORD,,
heavyball-1.4.3.dist-info/RECORD
DELETED

@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
-heavyball/utils.py,sha256=x0rSU8lko7ACdI9GuTLC0wP6HwIZxwB8f8tukBOR0xA,48129
-heavyball-1.4.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.4.3.dist-info/METADATA,sha256=RM_pOme3dsQL-drKcKD6FJ0qE3SSh4JdPM-kC9vpbeU,43584
-heavyball-1.4.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.4.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.4.3.dist-info/RECORD,,
{heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/LICENSE
File without changes

{heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/WHEEL
File without changes

{heavyball-1.4.3.dist-info → heavyball-1.5.0.dist-info}/top_level.txt
File without changes