heavyball 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +4 -26
- heavyball/chainable.py +73 -9
- heavyball/utils.py +97 -17
- {heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/METADATA +1 -1
- heavyball-1.5.3.dist-info/RECORD +8 -0
- heavyball-1.5.1.dist-info/RECORD +0 -8
- {heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/LICENSE +0 -0
- {heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/WHEEL +0 -0
- {heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -163,18 +163,8 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
-class ForeachAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
-
 
-class OrthoAdamW(C.BaseOpt):
+class LaPropOrtho(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
@@ -183,19 +173,7 @@ class OrthoAdamW(C.BaseOpt):
         defaults.pop("self")
         params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.
-
-
-class AdamWOrtho(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
-                         C.orthogonalize_grad_to_param)
+                         C.scale_by_laprop, C.orthogonalize_grad_to_param)
 
 
 class ForeachPSGDKron(C.BaseOpt):
@@ -216,7 +194,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
@@ -279,4 +257,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+          "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho']
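Taken together, this file drops what appears to be a duplicated ForeachAdamW definition, replaces the OrthoAdamW/AdamWOrtho pair with LaPropOrtho, and exports both OrthoLaProp and LaPropOrtho. A minimal usage sketch; the constructor keywords mirror the signatures shown above, while the toy model and training step are illustrative:

```python
import torch
from heavyball import OrthoLaProp, LaPropOrtho

model = torch.nn.Linear(16, 4)
# OrthoLaProp orthogonalizes the gradient against the parameter before LaProp scaling;
# LaPropOrtho applies the same two transforms in the opposite order.
opt = OrthoLaProp(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99), weight_decay=0.0)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```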
heavyball/chainable.py
CHANGED
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List
 
 import torch
 
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
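Both transforms decay toward an exponential moving average instead of toward zero, controlled by the new group hyperparameters `ema_beta` and `weight_decay_to_ema` (the latter scaled by `lr`). A generic plain-tensor sketch of the idea, not heavyball's exact argument wiring; `decay_toward_ema_` and `strength` are illustrative names:

```python
import torch

def decay_toward_ema_(param: torch.Tensor, ema: torch.Tensor, ema_beta: float, strength: float) -> None:
    # Generic "weight decay toward an EMA": pull the parameter toward a slow EMA of itself
    # instead of toward zero. Illustrative only, not heavyball's exact argument wiring.
    ema.lerp_(param, 1 - ema_beta)   # ema   <- ema_beta * ema + (1 - ema_beta) * param
    param.lerp_(ema, strength)       # param <- (1 - strength) * param + strength * ema

p, ema = torch.randn(8), torch.zeros(8)
decay_toward_ema_(p, ema, ema_beta=0.99, strength=2.5e-3 * 0.1)  # strength ~ lr * weight_decay_to_ema
```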
@@ -295,6 +311,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
 
 
+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
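`mup_approx` stores the per-column standard deviation of the first gradient it sees and then rescales later updates by it, skipping tensors with fewer than two dimensions. A plain-tensor sketch of that behaviour with illustrative variable names:

```python
import torch

first_grad = torch.randn(64, 32)
init_std = first_grad.std(dim=0)   # what _store_std keeps as state['init_std']

update = torch.randn(64, 32)
scaled = update * init_std         # what stochastic_multiply_ applies in place on later steps
```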
@@ -312,7 +347,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'], group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
 
     for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
@@ -364,10 +399,12 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache
 
 
-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-
-
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    else:
+        out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out
 
 
 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
@@ -387,7 +424,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +432,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -412,6 +449,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate
 
 
+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
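Since the new `sign` transform plugs into `C.BaseOpt` the same way the transforms in `__init__.py` do, custom optimizers can be assembled from it. A hypothetical sketch that mirrors the constructor pattern shown in the `__init__.py` hunks above; the class name `SignLaProp` and the transform order are illustrative and not part of this release:

```python
import heavyball.chainable as C

class SignLaProp(C.BaseOpt):
    """Hypothetical optimizer sketch: LaProp scaling followed by the new sign transform."""

    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
        defaults = locals()
        defaults.pop("self")
        params = defaults.pop("params")
        # The trailing callables form the transform chain, exactly as in the exported classes.
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.scale_by_laprop, C.sign)
```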
@@ -437,8 +479,7 @@ def apply_to_idx(fn, idx):
     return _fn
 
 
-def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -448,10 +489,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
 
 
+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(g, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
 
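`create_branch` runs several transform chains on copies of the update and hands the per-branch results to a user-supplied `merge_fn`. A hypothetical sketch of how it might be wired up; `mean_merge` is an illustrative name, and `scale_by_adam` / `scale_by_laprop` are the chainable transforms referenced elsewhere in this diff:

```python
import torch
import heavyball.chainable as C

def mean_merge(outputs):
    # `outputs` holds one per-parameter update list per branch; average them parameter-wise.
    return [torch.stack(branch_updates).mean(0) for branch_updates in zip(*outputs)]

# Blend an Adam-scaled branch with a LaProp-scaled branch into a single update.
blended = C.create_branch([[C.scale_by_adam], [C.scale_by_laprop]], mean_merge)
# `blended` exposes the (state, group, update, grad, param) signature used by _inner_chain,
# so it can be placed in a transform chain like any other chainable function.
```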
@@ -467,6 +528,8 @@ class ChainOpt(utils.StatefulOptimizer):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']
 
+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
 
         if not vals:
@@ -492,6 +555,7 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             chain(self.state_, group, g, p, *self.fns)
 
+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
 
heavyball/utils.py
CHANGED
@@ -317,6 +317,19 @@ def nesterov_momentum(state, grad, beta):
     return grad
 
 
+@decorator_knowngood
+def _compilable_nesterov_ema_(state, grad, beta):
+    ema32 = _lerp32(state, grad, beta)
+    stochastic_add_(grad, ema32, 1)
+
+
+def nesterov_ema(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_ema_(state, grad, beta)
+    return grad
+
+
 def _compilable_grafting(magnitude, direction):
     return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
 
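The chainable `nesterov_ema` transform (flagged above as equivalent to Grokfast) boils down to keeping an EMA of the gradient and adding it back on top of the raw gradient. A plain-tensor sketch with illustrative names, omitting heavyball's list handling and compilation:

```python
import torch

def ema_boosted_grad(grad: torch.Tensor, ema: torch.Tensor, beta: float) -> torch.Tensor:
    # Mirrors _compilable_nesterov_ema_: lerp the EMA toward the gradient, then add it back.
    ema.lerp_(grad, 1 - beta)   # ema <- beta * ema + (1 - beta) * grad
    return grad + ema           # stochastic_add_(grad, ema32, 1)

g, ema = torch.randn(10), torch.zeros(10)
filtered = ema_boosted_grad(g, ema, beta=0.9)
```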
@@ -509,6 +522,19 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)
 
 
+@decorator_knowngood
+def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 * y32)
+
+
+def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+    x, y = list_guard(x, y)
+    _compilable_stochastic_multiply_(x, y)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
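The new helper does the multiplication in float32 and writes the result back through `copy_stochastic_`, which matters when the destination tensor is stored in a low-precision dtype. A small usage sketch, assuming `stochastic_multiply_` is importable from `heavyball.utils` as added above:

```python
import torch
from heavyball.utils import stochastic_multiply_

x = [torch.randn(8, dtype=torch.bfloat16)]   # low-precision storage
y = [torch.full((8,), 1.01)]
stochastic_multiply_(x, y)   # in place: x <- copy_stochastic_(float32(x) * float32(y))
```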
@@ -783,7 +809,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
-          eps: float):
+          eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -815,23 +841,23 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
 
 @decorator_knowngood
 def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
-                        beta2: Tensor, step: Tensor):
+                        beta2: Tensor, step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
     gp32 = list(map(promote, grad))
 
-    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2,
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, eps)
     gp32 = torch._foreach_div(gp32, denom)
     gp32 = _lerp32(exp_avg, gp32, beta1)
 
     copy_stochastic_list_(grad, gp32)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, eps: float = 1e-8):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, exp_avg[0], eps)
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
@@ -970,6 +996,10 @@ def get_soap_precond_schedule(precond_scheduler):
     return _inner
 
 
+def _max_idx(x: List[int]):
+    return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
+
+
 def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
@@ -992,17 +1022,20 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
 
     scale = scale ** (1 / len(shape))
 
+    dim_diag = [False for _ in shape]
     if memory_save_mode is None:
-
+        pass
     elif memory_save_mode == "one_diag":
-
-
-
+        dim_diag[_max_idx(shape)] = True
+    elif memory_save_mode == "smart_one_diag":
+        sorted_shape = sorted(shape)
+        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+            dim_diag[_max_idx(shape)] = True
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
         raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
-                         "[None, 'one_diag', 'all_diag']")
+                         "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
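Compared with `one_diag`, the new `smart_one_diag` mode only switches the largest dimension to a diagonal preconditioner when the shape is genuinely rectangular. A standalone illustration of the same selection rule shown above:

```python
import numpy as np

def _max_idx(x):
    # Same rule as above: index of the largest entry, counting ties from the back.
    return len(x) - 1 - int(np.argmax(x[::-1]))

def smart_one_diag(shape):
    dim_diag = [False for _ in shape]
    if len(shape) >= 2 and sorted(shape)[-1] > sorted(shape)[-2]:
        dim_diag[_max_idx(shape)] = True
    return dim_diag

print(smart_one_diag((4096, 128)))  # [True, False]  - clearly rectangular: largest side goes diagonal
print(smart_one_diag((512, 512)))   # [False, False] - square: both sides keep full preconditioners
```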
@@ -1221,6 +1254,48 @@ def identity(x):
     return x
 
 
+@decorator_knowngood
+def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp32(ema, p, ema_decay)
+    _lerp32(p, ema32, 1 - weight_decay)
+
+
+def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_deacy, weight_decay):
+    ema32 = _lerp32(ema, p, ema_deacy)
+    for p_, e_ in zip(p, ema32):
+        p32 = promote(p)
+        p32 = p32 + (p32 - e_).sign() * weight_decay
+        copy_stochastic_(p_, p32)
+
+
+def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_sign_(grad: List[Tensor], graft: bool):
+    for g_ in grad:
+        gs = g_.sign()
+        if graft:
+            gs = _compilable_grafting(g_, gs)
+        copy_stochastic_(g_, gs)
+
+
+def sign_(grad: List[Tensor], graft: bool = True):
+    grad = list_guard(grad)
+    _compilable_sign_(grad, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
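`sign_` optionally grafts the sign direction back onto the original update's norm via `_compilable_grafting`, so the direction is flattened while the overall magnitude is preserved. A quick check of that property, assuming `sign_` is importable from `heavyball.utils` as added above:

```python
import torch
from heavyball.utils import sign_

g = [torch.randn(4, 4)]
norm_before = g[0].norm()

sign_(g, graft=True)   # in place: sign(g) rescaled back to the original Frobenius norm
print(torch.allclose(g[0].norm(), norm_before, rtol=1e-2))  # magnitude preserved by grafting
print((g[0].abs().std() < 1e-2).item())                     # every entry now shares one magnitude
```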
@@ -1300,7 +1375,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1312,8 +1390,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1322,7 +1400,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca
 
 
 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1332,8 +1412,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
heavyball-1.5.3.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=Ex6GLyySA-wL2tNNqn9FHHy4I5CmqvhqDkaeBvyGEn0,12806
+heavyball/chainable.py,sha256=W3tLXPXMWtzWNbPllEKtAh8W2HSD69NBBZtoO8egsew,27099
+heavyball/utils.py,sha256=Dtb9QEWRAXzUMHqbOIefjJnteje_Xw6J-Mk-Y4TM9p0,52930
+heavyball-1.5.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.3.dist-info/METADATA,sha256=ovxnzDu2GP9mdt9fmCUZPWAQvWEg0EYr6X1Vfu_SzO0,43584
+heavyball-1.5.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.3.dist-info/RECORD,,
heavyball-1.5.1.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
-heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
-heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
-heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
-heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.1.dist-info/RECORD,,
{heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/LICENSE
File without changes
{heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/WHEEL
File without changes
{heavyball-1.5.1.dist-info → heavyball-1.5.3.dist-info}/top_level.txt
File without changes