heavyball 1.3.0__tar.gz → 1.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.3.0 → heavyball-1.3.1}/PKG-INFO +1 -1
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/chainable.py +57 -19
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/utils.py +68 -35
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.3.0 → heavyball-1.3.1}/setup.py +1 -1
- {heavyball-1.3.0 → heavyball-1.3.1}/LICENSE +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/README.md +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball/__init__.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/setup.cfg +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_params.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_q.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_bf16_storage.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_caution.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_channels_last.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_closure.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_ema.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_foreach.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_hook.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_mars.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_memory.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_merge.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_no_grad.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_psgd.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_soap.py +0 -0
- {heavyball-1.3.0 → heavyball-1.3.1}/test/test_stochastic_updates.py +0 -0
{heavyball-1.3.0 → heavyball-1.3.1}/heavyball/chainable.py

```diff
@@ -307,7 +307,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond
 
 
-def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
```
```diff
@@ -322,6 +322,7 @@ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[cal
     else:
         utils.psgd_balance_Q(Q)
 
+    _update_psgd_cache(cached, Q_cache, Q_mat)
 
 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
```
```diff
@@ -357,7 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
 
 
```
```diff
@@ -367,7 +368,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return precond
 
 
```
```diff
@@ -376,7 +377,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate
 
```
```diff
@@ -387,7 +388,7 @@ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     raise SkipUpdate
 
 
```
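Taken together, these hunks thread `cached` and `Q_cache` through every `_update_psgd_precond` call site, so the function can refresh the cached preconditioner itself (the new `_update_psgd_cache` call) whenever `Q` changes, instead of letting the cache drift out of sync. The four call sites differ only in ordering. A toy sketch of the two orderings, with hypothetical `fit_precond`/`apply_precond` stand-ins rather than heavyball's real internals:

```python
# Hypothetical stand-ins for _update_psgd_precond / _cached_psgd_precond_grad.
def fit_precond(state, grad):
    state["q"] = 0.9 * state["q"] + 0.1 / (abs(grad) + 1.0)  # toy fit of Q
    state["cache"] = state["q"] ** 2                         # refresh cache with the new Q

def apply_precond(state, update):
    return state["cache"] * update

def scale_psgd(update, state):           # scale_by_psgd ordering
    fit_precond(state, update)           # update Q (and its cache) first ...
    return apply_precond(state, update)  # ... then precondition with fresh values

def scale_delayed_psgd(update, state):   # scale_by_delayed_psgd ordering
    out = apply_precond(state, update)   # precondition with last step's Q ...
    fit_precond(state, update)           # ... then refit for the next step
    return out

state = {"q": 1.0, "cache": 1.0}
print(scale_psgd(0.5, state), scale_delayed_psgd(0.5, state))
```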
```diff
@@ -422,6 +423,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 class ChainOpt(utils.StatefulOptimizer):
     compile_step: bool = False
+    promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
```
```diff
@@ -431,7 +433,7 @@ class ChainOpt(utils.StatefulOptimizer):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
 
-        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)
```
```diff
@@ -486,36 +488,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
                         scale_by_adam.get_fn(): update_by_adam,  #
                         scale_by_laprop.get_fn(): update_by_laprop,  #
                         scale_by_adopt.get_fn(): update_by_adopt}
+_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd,  #
+                            update_by_psgd.get_fn(): scale_by_psgd,  #
+                            update_by_adam.get_fn(): scale_by_adam,  #
+                            update_by_laprop.get_fn(): scale_by_laprop,  #
+                            update_by_adopt.get_fn(): scale_by_adopt}
 
 
 class BaseOpt(ChainOpt):
+    """
+    Base Optimizer
+
+    compile_step: bool = False
+    Whether to change some internals to try to make the optimizer compilable
+    This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+
+    promote: bool = False
+    Whether to promote the gradients to fp32 before applying the optimizer
+    Improves update quality for low-precision parameters, but increases costs
+    Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+
+    gradient_clipping: str_or_fn = None
+    The function to use for clipping the incoming gradients, before any other transformations.
+    This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+    update_clipping: str_or_fn = None
+    The function to use for clipping the outgoing updates before applying them, after all other transformations.
+    This will turn off
+    This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+    """
+
     gradient_clipping: str_or_fn = None
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
 
     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, *fns, compile_step: bool = use_default):
+                 palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+        if not fns:
+            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+        args, kwargs = None, None
+        fn = fns[-1]
+        if isinstance(fn, functools.partial):
+            fn, args, kwargs = fn.func, fn.args, fn.keywords
+        if isinstance(fn, FunctionTransform):
+            fn = fn.get_fn()
+
         if default(update_clipping, self.update_clipping) is None:
-            if self.auto_fuse:
-                args, kwargs = None, None
-                fn = fns[-1]
-                if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fn.func, fn.args, fn.keywords
-                if isinstance(fn, FunctionTransform):
-                    fn = fn.get_fn()
+            if self.auto_fuse:
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
                         fn = functools.partial(fn, *args, **kwargs)
                     fns = tuple(fns)[:-1] + (fn,)
-
-        if not fns:
-            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
-
-
+        elif fn in _scale_to_update_map_inv:
+            if not self.auto_fuse:
+                raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+                                 "Manually select scale_by_* functions or set auto_fuse=True.")
+            fn = _scale_to_update_map_inv[fn]
+            if args is not None:
+                fn = functools.partial(fn, *args, **kwargs)
+            fns = tuple(fns)[:-1] + (fn,)
 
         self.compile_step = default(compile_step, self.compile_step)
+        self.promote = default(promote, self.promote)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
```
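The constructor rework is easier to follow outside the diff: the last transform of the chain is unwrapped once up front (through `functools.partial` and `FunctionTransform`), then either fused via `_scale_to_update_map` when there is no `update_clipping`, or unfused via the new `_scale_to_update_map_inv` when `update_clipping` must observe the final update. A standalone sketch of that unwrap-and-remap control flow, using toy stand-in transforms (only the logic mirrors the diff; none of this is heavyball's API):

```python
import functools

def scale_by_adam(*a, **k): ...      # toy stand-ins; heavyball's real transforms
def update_by_adam(*a, **k): ...     # are FunctionTransform objects

_scale_to_update_map = {scale_by_adam: update_by_adam}
_scale_to_update_map_inv = {update_by_adam: scale_by_adam}

def remap_tail(fns: tuple, update_clipping, auto_fuse: bool = True) -> tuple:
    args, kwargs = None, None
    fn = fns[-1]
    if isinstance(fn, functools.partial):          # unwrap, remembering bound args
        fn, args, kwargs = fn.func, fn.args, fn.keywords
    if update_clipping is None:
        if auto_fuse and fn in _scale_to_update_map:
            fn = _scale_to_update_map[fn]          # fuse: apply the update in-place
    elif fn in _scale_to_update_map_inv:
        if not auto_fuse:
            raise ValueError("update_clipping is not compatible with update_by_* functions")
        fn = _scale_to_update_map_inv[fn]          # unfuse so clipping sees the update
    if args is not None:
        fn = functools.partial(fn, *args, **kwargs)  # re-wrap with the original args
    return fns[:-1] + (fn,)

print(remap_tail((scale_by_adam,), update_clipping=None))        # -> (update_by_adam,)
print(remap_tail((update_by_adam,), update_clipping="l2_clip"))  # -> (scale_by_adam,)
```

The new `promote` flag threads through the same constructor into `split_p_and_g_in_group` (hunk above): when enabled, gradients are upcast to fp32 before the chain runs, which the new docstring notes improves update quality for low-precision parameters at extra cost.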
{heavyball-1.3.0 → heavyball-1.3.1}/heavyball/utils.py

```diff
@@ -1101,65 +1101,98 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
 
 
 @decorator_knowngood
-def _compilable_l2_clip_(x):
+def _compilable_l2_clip_(x, clip_at):
     ref = x
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-    torch._foreach_maximum_(norm, 1e-8)
+    torch._foreach_maximum_(norm, clip_at)
     out = torch._foreach_div(x, norm)
     return stochastic_round_list_(ref, out)
 
 
-def l2_clip_(x):
+def l2_normalization_(x, clip_at: float = 1e-8):
     x = list_guard(x)
-    return _compilable_l2_clip_(x)
+    return _compilable_l2_clip_(x, clip_at)
+
+
+def l2_clip_(x, clip_at: float = 1.):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x, clip_at)
 
 
```
```diff
 @decorator_knowngood
-def _compilable_rmsnorm_clip_(x):
+def _compilable_rmsnorm_clip_(x, clip_at):
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm, 1e-6)
+    torch._foreach_maximum_(norm, clip_at)
     return torch._foreach_div(x, norm)
 
 
-def rmsnorm_clip_(x):
+def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-    return _compilable_rmsnorm_clip_(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
 
 
-def mu_law_compress(x, mu=127.0):
+def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
+
+
```
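The `clip_at` floor is what separates the new clip/normalize pairs: `l2_clip_` defaults to 1, so tensors whose norm is already below 1 pass through unchanged, while `l2_normalization_` floors at 1e-8, so effectively every tensor is rescaled to unit norm; the RMS variants do the same after dividing the norm by `numel() ** 0.5`. A plain-PyTorch sketch of the shared floor-and-divide semantics (not heavyball's foreach/stochastic-rounding path):

```python
import torch

def l2_scale(x: torch.Tensor, clip_at: float) -> torch.Tensor:
    return x / x.norm().clamp(min=clip_at)  # divide by max(norm, clip_at)

g = torch.tensor([0.3, 0.4])           # norm = 0.5
print(l2_scale(g, clip_at=1.0))        # clip:      unchanged (norm already <= 1)
print(l2_scale(g, clip_at=1e-8))       # normalize: rescaled to unit norm
print(l2_scale(10 * g, clip_at=1.0))   # clip:      rescaled down to norm 1
```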
```diff
+@decorator_knowngood
+def _compilable_mu_law_compress_(x, mu):
+    """
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
     """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
 
+    for x_ in x:
+        xa = promote(x_.abs()) * mu
+        xa = xa.log1p()
+        xa = xa / math.log1p(mu)
+        xa = xa.copysign(x_)
+        copy_stochastic_(x_, xa)
+
+
+def mu_law_compress(x, mu=127.0):
+    """
     μ-law compression
     Args:
         x: Input tensor
         mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
     """
-
-
-
-
-    return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+    x = list_guard(x)
+    mu = scalar_guard(mu, x[0])
+    _compilable_mu_law_compress(x, mu)
+    return x
 
 
```
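The loop body implements μ-law companding, mapping x to sign(x) * log1p(mu * |x|) / log1p(mu), which compresses large magnitudes while fixing the endpoints at ±1. A plain-PyTorch recomputation of that formula (without heavyball's `promote`/`copy_stochastic_` machinery):

```python
import math
import torch

def mu_law(x: torch.Tensor, mu: float = 127.0) -> torch.Tensor:
    # sign(x) * log1p(mu * |x|) / log1p(mu); maps [-1, 1] onto [-1, 1]
    return (x.abs().mul(mu).log1p() / math.log1p(mu)).copysign(x)

x = torch.tensor([-1.0, -0.1, 0.0, 0.1, 1.0])
print(mu_law(x))  # endpoints stay at ±1, small magnitudes are boosted
```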
```diff
-def a_law_compress(x, A=87.6):
+@decorator_knowngood
+def _compilable_a_law_compress_(x, A):
+    """
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
     """
-    Foreach version of https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    for x_ in x:
+        xa = promote(x_.abs()) * A
+        xa = torch.where(xa < 1, xa, 1 + xa.log())
+        xa = xa.copysign(x_)
+        xa = xa * (1 / (1 + math.log(A)))
+        copy_stochastic_(x_, xa)
 
+
+def a_law_compress(x, A=87.6):
+    """
     A-law compression
     Args:
         x: Input tensor
         A: Compression parameter (default 87.6 - European PCM standard)
+    :param x:
+    :param A:
+    :return:
     """
-
-
-
-
-    torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
-    return xa
+    x = list_guard(x)
+    A = scalar_guard(A, x[0])
+    _compilable_a_law_compress(x, A)
+    return x
 
 
 def identity(x):
```
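A-law companding is piecewise: linear for A|x| < 1, logarithmic above, scaled by 1 / (1 + log A). The same formula as the loop above, restated in plain PyTorch:

```python
import math
import torch

def a_law(x: torch.Tensor, A: float = 87.6) -> torch.Tensor:
    # sign(x) * (A|x| if A|x| < 1 else 1 + log(A|x|)) / (1 + log A)
    xa = x.abs() * A
    xa = torch.where(xa < 1, xa, 1 + xa.log())
    return (xa / (1 + math.log(A))).copysign(x)

x = torch.tensor([0.005, 0.5, 1.0])
print(a_law(x))  # linear below |x| = 1/A, logarithmic above, a_law(1) = 1
```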
```diff
@@ -1167,24 +1200,24 @@ def identity(x):
 
 
 @decorator_knowngood
-def _compilable_trust_region_clip_(grad, lerp=0.9, scale=1.5):
+def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
-
-
-
-
-
-
-
-
-
-    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+    for x_ in grad:
+        x = promote(x_)
+        x = x / scale
+        tanh = x.tanh()
+        x = x.abs().log1p()
+        x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+        x = x * scale
+        x = x.clamp(min=-2, max=2)
+        copy_stochastic_(x_, x)
 
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
     lerp, scale = scalar_guard(lerp, scale, grad[0])
-    return _compilable_trust_region_clip_(grad, lerp, scale)
+    _compilable_trust_region_clip_(grad, lerp, scale)
+    return grad
 
 
 @decorator
```
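The rewritten kernel blends a log-compressed magnitude with tanh (weighted by `lerp`), rescales by `scale`, and clamps to ±2, matching the comment in the hunk. A standalone per-tensor recomputation (plain PyTorch, omitting heavyball's `promote` and stochastic rounding):

```python
import torch

def trust_region(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
    x = x / scale
    tanh = x.tanh()
    # blend sign(x) * log1p(|x|) with tanh(x): 10% log-compression, 90% tanh by default
    x = x.abs().log1p().copysign(tanh) * (1 - lerp) + tanh * lerp
    return (x * scale).clamp(min=-2, max=2)

g = torch.tensor([-10.0, -1.0, 0.1, 10.0])
print(trust_region(g))  # large entries are squashed, small ones pass nearly unchanged
```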