heavyball 1.3.0__tar.gz → 1.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.3.0 → heavyball-1.4.0}/PKG-INFO +1 -1
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/chainable.py +91 -37
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/utils.py +85 -52
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.3.0 → heavyball-1.4.0}/setup.py +1 -1
- {heavyball-1.3.0 → heavyball-1.4.0}/LICENSE +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/README.md +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball/__init__.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/setup.cfg +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_params.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_q.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_bf16_storage.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_caution.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_channels_last.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_closure.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_ema.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_foreach.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_hook.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_mars.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_memory.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_merge.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_no_grad.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_psgd.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_soap.py +0 -0
- {heavyball-1.3.0 → heavyball-1.4.0}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,5 @@
 import functools
 import random
-import warnings
 from typing import Optional, Union, Literal
 
 import torch
@@ -85,10 +84,11 @@ class CopyGuard(FunctionTransform):
 
 
 class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
-    def __init__(self, fn, names, init_fn):
+    def __init__(self, fn, names, init_fn, skip_first: bool = True):
         super().__init__(fn)
         self.names = names
         self.init_fn = init_fn
+        self.skip_first = skip_first
 
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         vars = []
@@ -97,7 +97,7 @@ class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
             st = state(p)
             skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
             vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
-        if skip_update:
+        if skip_update and self.skip_first:
            raise SkipUpdate
        return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
@@ -109,8 +109,17 @@ class NoState(FunctionTransform):
 
 class NoStateNoForeach(FunctionTransform):
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        updates = []
+        skip_update = False
         for a in zip(update, grad, param, *args):
-
+            try:
+                updates.append(self.fn(group, *a, **kwargs))
+            except SkipUpdate:
+                skip_update = True
+                pass
+        if skip_update:
+            raise SkipUpdate
+        return updates
 
 
 def zero_guard(*names):
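
For orientation, the reworked NoStateNoForeach loop above applies the wrapped function per tensor tuple, collects the results, and only re-raises SkipUpdate once every tensor has been visited. A minimal standalone sketch of that control flow (hypothetical step_one/apply_all names, not heavyball code):

class SkipUpdate(Exception):
    """Signals that a transform produced no update for this input."""


def step_one(x: float) -> float:
    # Hypothetical per-tensor transform: skip non-finite inputs.
    if x != x:  # NaN check without importing math
        raise SkipUpdate
    return 0.5 * x


def apply_all(xs):
    # Collect whatever succeeded; defer the skip until the whole loop ran,
    # mirroring the updates/skip_update bookkeeping in the hunk above.
    outs, skipped = [], False
    for x in xs:
        try:
            outs.append(step_one(x))
        except SkipUpdate:
            skipped = True
    if skipped:
        raise SkipUpdate
    return outs


print(apply_all([1.0, 2.0]))  # [0.5, 1.0]
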
@@ -118,11 +127,11 @@ def zero_guard(*names):
 
 
 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names)
+    return functools.partial(CopyGuard, index=index, names=names,)
 
 
-def general_guard(*names, init_fn):
-    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
+def general_guard(*names, init_fn, skip_first: bool = True):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
 
 
 def no_state(fn):
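
As a reminder of the pattern being extended here, these guard factories only pre-bind keyword arguments with functools.partial; the guard class is instantiated later around the wrapped function. A small illustrative sketch (simplified Guard class, not the heavyball implementation):

import functools


class Guard:
    def __init__(self, fn, names, init_fn, skip_first=True):
        self.fn, self.names, self.init_fn, self.skip_first = fn, names, init_fn, skip_first


def general_guard(*names, init_fn, skip_first=True):
    # Everything except `fn` is bound now; the guard is built later around fn.
    return functools.partial(Guard, names=names, init_fn=init_fn, skip_first=skip_first)


make_guard = general_guard("Q", "exprs", init_fn=dict, skip_first=False)
guard = make_guard(fn=print)
print(guard.names, guard.skip_first)  # ('Q', 'exprs') False
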
@@ -307,21 +316,22 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond
 
 
-def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
-        return
+        return Q_mat
 
-    Q = [utils.promote(q_) for q_ in Q]
     utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
 
-    if grad.dim() > 1 and precond_schedule(group, balance_probability, "
+    if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
         if group['store_triu_as_line']:
             utils.psgd_balance_Q([q_ for _, q_ in Q])
         else:
             utils.psgd_balance_Q(Q)
 
+    return _update_psgd_cache(cached, Q_cache, Q_mat)
+
 
 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
@@ -350,44 +360,47 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
                                     group['caution'], *Q_mat)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
-    old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     return precond
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, update
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     raise SkipUpdate
 
 
@@ -421,7 +434,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 
 class ChainOpt(utils.StatefulOptimizer):
-
+    promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
@@ -430,8 +443,12 @@ class ChainOpt(utils.StatefulOptimizer):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
+        if 'prev_lr' in group and group['prev_lr'] != group['lr']:
+            utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
+                            f'only supported with foreach=True (currently foreach={group["foreach"]}).')
+            group['base_lr'] = group['lr']
 
-        vals = list(self.split_p_and_g_in_group(group, should_promote=
+        vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)
@@ -449,9 +466,10 @@ class ChainOpt(utils.StatefulOptimizer):
         group['step'] = state['step'] = step = step + 1
 
         if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+            group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
+
         else:
-            group['lr'] = group['base_lr']
+            group['prev_lr'] = group['lr'] = group['base_lr']
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
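
The warmup branch above is plain linear scaling of the stored base learning rate; a small worked example with made-up numbers:

def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Linear warmup as in the hunk above: base_lr * step / warmup_steps,
    # falling back to base_lr once warmup is done (or disabled).
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr


print(warmup_lr(1e-3, 25, 100))   # 0.00025
print(warmup_lr(1e-3, 100, 100))  # 0.001
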
@@ -486,36 +504,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
                         scale_by_adam.get_fn(): update_by_adam,  #
                         scale_by_laprop.get_fn(): update_by_laprop,  #
                         scale_by_adopt.get_fn(): update_by_adopt}
+_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd,  #
+                            update_by_psgd.get_fn(): scale_by_psgd,  #
+                            update_by_adam.get_fn(): scale_by_adam,  #
+                            update_by_laprop.get_fn(): scale_by_laprop,  #
+                            update_by_adopt.get_fn(): scale_by_adopt}
 
 
 class BaseOpt(ChainOpt):
+    """
+    Base Optimizer
+
+    compile_step: bool = False
+        Whether to change some internals to try to make the optimizer compilable
+        This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+
+    promote: bool = False
+        Whether to promote the gradients to fp32 before applying the optimizer
+        Improves update quality for low-precision parameters, but increases costs
+        Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+
+    gradient_clipping: str_or_fn = None
+        The function to use for clipping the incoming gradients, before any other transformations.
+        This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+    update_clipping: str_or_fn = None
+        The function to use for clipping the outgoing updates before applying them, after all other transformations.
+        This will turn off
+        This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+    """
+
     gradient_clipping: str_or_fn = None
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
 
     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, *fns, compile_step: bool = use_default):
+                 palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+        if not fns:
+            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+        args, kwargs = None, None
+        fn = fns[-1]
+        if isinstance(fn, functools.partial):
+            fn, args, kwargs = fn.func, fn.args, fn.keywords
+        if isinstance(fn, FunctionTransform):
+            fn = fn.get_fn()
+
         if default(update_clipping, self.update_clipping) is None:
-            if
-                args, kwargs = None, None
-                fn = fns[-1]
-                if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fn.func, fn.args, fn.keywords
-                if isinstance(fn, FunctionTransform):
-                    fn = fn.get_fn()
+            if self.auto_fuse:
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
                         fn = functools.partial(fn, *args, **kwargs)
                     fns = tuple(fns)[:-1] + (fn,)
-
-        if
-            raise ValueError("
-
-
+        elif fn in _scale_to_update_map_inv:
+            if not self.auto_fuse:
+                raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+                                 "Manually select scale_by_* functions or set auto_fuse=True.")
+            fn = _scale_to_update_map_inv[fn]
+            if args is not None:
+                fn = functools.partial(fn, *args, **kwargs)
+            fns = tuple(fns)[:-1] + (fn,)
 
         self.compile_step = default(compile_step, self.compile_step)
+        self.promote = default(promote, self.promote)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
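
The constructor change above unwraps the last chain element once (through functools.partial and FunctionTransform), then either fuses a scale_by_* into its update_by_* twin when no update clipping is set, or maps an update_by_* back to scale_by_* so the clipping function can see the raw update. A self-contained sketch of that decision with stand-in functions (not the real transforms):

def scale_by_adam(): ...
def update_by_adam(): ...

SCALE_TO_UPDATE = {scale_by_adam: update_by_adam}
UPDATE_TO_SCALE = {update_by_adam: scale_by_adam}


def resolve_last_fn(fn, update_clipping=None, auto_fuse=True):
    # No update clipping: prefer the fused update_by_* variant when allowed.
    if update_clipping is None:
        if auto_fuse and fn in SCALE_TO_UPDATE:
            return SCALE_TO_UPDATE[fn]
        return fn
    # Update clipping needs the unfused scale_by_* output to clip.
    if fn in UPDATE_TO_SCALE:
        if not auto_fuse:
            raise ValueError("update_clipping is not compatible with update_by_* functions")
        return UPDATE_TO_SCALE[fn]
    return fn


print(resolve_last_fn(scale_by_adam).__name__)                         # update_by_adam
print(resolve_last_fn(update_by_adam, update_clipping="l2").__name__)  # scale_by_adam
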
@@ -193,12 +193,12 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad
 
 
-# TODO: This lerp was fucked - check other lerps
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-
-
-
+    for s, g in zip(state, grad):
+        lerped = s.lerp(g, 1 - beta)
+        copy_stochastic_(s, lerped)
+        copy_stochastic_(g, lerped)
 
 
 def scale_by_exp_avg_(state, grad, beta):
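
For reference, s.lerp(g, 1 - beta) in the new _compilable_exp_avg_ body is the standard exponential moving average beta * s + (1 - beta) * g; a scalar check (plain floats, not the tensor code):

def ema(s: float, g: float, beta: float) -> float:
    # lerp(s, g, 1 - beta) == s + (1 - beta) * (g - s) == beta * s + (1 - beta) * g
    return s + (1 - beta) * (g - s)


print(round(ema(0.0, 1.0, 0.9), 6))  # 0.1 — the new value pulls the average by (1 - beta)
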
@@ -592,6 +592,7 @@ def project(grad, Q, back: bool):
 
 class StatefulOptimizer(torch.optim.Optimizer):
     ema_decay: float = 0.001
+    compile_step: bool = False
 
     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, 'foreach': foreach})
@@ -637,6 +638,10 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
             p.grad = None
 
+            if self.compile_step:
+                yield p, grad
+                continue
+
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
@@ -1030,7 +1035,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
     V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
     eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
     eps *= G.norm() / G.numel()
-    G
+    G = G + V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
@@ -1078,88 +1083,115 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
     term1 = promote(torch.einsum(exprG, A, A))
     term2 = promote(torch.einsum(exprG, conjB, conjB))
 
-    term2
-    term1 *= 2  # 2a
-    if term1.dtype == term2.dtype:
-        term1 -= term2  # 2a - (a + b) == a - b
-    else:
-        term1 = term1 - term2
+    term1, term2 = term1 - term2, term1 + term2
 
     term1 *= precond_lr
     norm = term2.norm(float('inf'))
     if q.dim() < 2:
-        term1 *= q.to(term1.dtype)
-        term1 /= norm.clamp_(min=tiny_bf16)
+        term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
     else:
         torch.triu(term1, out=term1)
-        term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
-        torch.
+        term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
+        term1 = torch.mm(term1, q)
     if store_triu_as_line:
         term1 = triu_to_line([term1])[0][1]
         o = o[1]
-    stochastic_add_(
+    stochastic_add_(o, term1, -1)
 
 
 @decorator_knowngood
-def _compilable_l2_clip_(x):
+def _compilable_l2_clip_(x, clip_at):
     ref = x
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-    torch._foreach_maximum_(norm,
+    torch._foreach_maximum_(norm, clip_at)
     out = torch._foreach_div(x, norm)
     return stochastic_round_list_(ref, out)
 
 
-def
+def l2_normalization_(x, clip_at: float = 1e-8):
     x = list_guard(x)
-    return _compilable_l2_clip_(x)
+    return _compilable_l2_clip_(x, clip_at)
+
+
+def l2_clip_(x, clip_at: float = 1.):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x, clip_at)
 
 
 @decorator_knowngood
-def _compilable_rmsnorm_clip_(x):
+def _compilable_rmsnorm_clip_(x, clip_at):
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm,
+    torch._foreach_maximum_(norm, clip_at)
     return torch._foreach_div(x, norm)
 
 
-def rmsnorm_clip_(x):
+def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-    return _compilable_rmsnorm_clip_(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
 
 
-def
+def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
+
+
+@decorator_knowngood
+def _compilable_mu_law_compress_(x, mu):
     """
-
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    """
+
+    for x_ in x:
+        xa = promote(x_.abs()) * mu
+        xa = xa.log1p()
+        xa = xa / math.log1p(mu)
+        xa = xa.copysign(x_)
+        copy_stochastic_(x_, xa)
 
+
+def mu_law_compress(x, mu=127.0):
+    """
     μ-law compression
     Args:
         x: Input tensor
         mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
     """
-
-
-
-
-    return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+    x = list_guard(x)
+    mu = scalar_guard(mu, x[0])
+    _compilable_mu_law_compress_(x, mu)
+    return x
 
 
-
+@decorator_knowngood
+def _compilable_a_law_compress_(x, A):
+    """
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
     """
-
+    for x_ in x:
+        xa = promote(x_.abs()) * A
+        xa = torch.where(xa < 1, xa, 1 + xa.log())
+        xa = xa.copysign(x_)
+        xa = xa * (1 / (1 + math.log(A)))
+        copy_stochastic_(x_, xa)
 
+
+def a_law_compress(x, A=87.6):
+    """
     A-law compression
     Args:
         x: Input tensor
         A: Compression parameter (default 87.6 - European PCM standard)
+    :param x:
+    :param A:
+    :return:
     """
-
-
-
-
-    torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
-    return xa
+    x = list_guard(x)
+    A = scalar_guard(A, x[0])
+    _compilable_a_law_compress_(x, A)
+    return x
 
 
 def identity(x):
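
The two new compressors follow the classic μ-law and A-law companding formulas; a scalar restatement using only the math module (it mirrors the tensor code above, but is not taken from the package):

import math


def mu_law(x: float, mu: float = 127.0) -> float:
    # sign(x) * log(1 + mu * |x|) / log(1 + mu)
    return math.copysign(math.log1p(mu * abs(x)) / math.log1p(mu), x)


def a_law(x: float, a: float = 87.6) -> float:
    # sign(x) * (A|x| if A|x| < 1 else 1 + log(A|x|)) / (1 + log(A))
    ax = a * abs(x)
    mag = ax if ax < 1 else 1 + math.log(ax)
    return math.copysign(mag / (1 + math.log(a)), x)


print(round(mu_law(0.5), 3))  # ≈ 0.859
print(round(a_law(0.5), 3))   # ≈ 0.873
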
@@ -1167,24 +1199,24 @@ def identity(x):
 
 
 @decorator_knowngood
-def _compilable_trust_region_clip_(grad, lerp
+def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
-
-
-
-
-
-
-
-
-
-    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+    for x_ in grad:
+        x = promote(x_)
+        x = x / scale
+        tanh = x.tanh()
+        x = x.abs().log1p()
+        x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+        x = x * scale
+        x = x.clamp(min=-2, max=2)
+        copy_stochastic_(x_, x)
 
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
     lerp, scale = scalar_guard(lerp, scale, grad[0])
-
+    _compilable_trust_region_clip_(grad, lerp, scale)
+    return grad
 
 
 @decorator
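
The rewritten trust-region clip divides by scale, blends a signed log1p curve with tanh using lerp, rescales, and clamps to [-2, 2], as the inline comment suggests; a scalar restatement (plain math, not the foreach/tensor code):

import math


def trust_region(x: float, lerp: float = 0.9, scale: float = 1.5) -> float:
    # scale * ((1 - lerp) * sign(x) * log1p(|x / scale|) + lerp * tanh(x / scale)), clamped to [-2, 2]
    xs = x / scale
    soft = math.copysign(math.log1p(abs(xs)), math.tanh(xs))
    mixed = soft * (1 - lerp) + math.tanh(xs) * lerp
    return max(-2.0, min(2.0, mixed * scale))


print(round(trust_region(10.0), 3))  # ≈ 1.656 — large inputs are squashed, then capped at ±2
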
@@ -1262,6 +1294,7 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
 
@@ -1277,7 +1310,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr,
+    precond = psgd_precond_grad(expr, ea, *preconds)
     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
 
 
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|