heavyball 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/chainable.py +64 -25
- heavyball/utils.py +74 -42
- {heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/METADATA +1 -1
- heavyball-1.3.1.dist-info/RECORD +8 -0
- heavyball-1.2.3.dist-info/RECORD +0 -8
- {heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/LICENSE +0 -0
- {heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/WHEEL +0 -0
- {heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -307,7 +307,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
     return precond


-def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
+def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
@@ -322,6 +322,7 @@ def _update_psgd_precond(group, param, grad, Q_mat, Q, exprs, prob: Optional[cal
     else:
         utils.psgd_balance_Q(Q)

+    _update_psgd_cache(cached, Q_cache, Q_mat)

 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
@@ -357,7 +358,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)


@@ -367,7 +368,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     return precond


@@ -376,7 +377,7 @@ def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_e
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate

@@ -387,7 +388,7 @@ def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(group, param,
+    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
     raise SkipUpdate


@@ -422,6 +423,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):

 class ChainOpt(utils.StatefulOptimizer):
     compile_step: bool = False
+    promote: bool = False

     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
@@ -431,21 +433,22 @@ class ChainOpt(utils.StatefulOptimizer):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']

-        vals = list(self.split_p_and_g_in_group(group, should_promote=
+        vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)

         for param in p:
             state = self.state_(param)
-            if 'step'
-
-
-
-
+            if 'step' in state:
+                step = state['step']
+            elif self.compile_step:
+                step = utils.scalar_guard(0, param)
+            else:
+                step = 0
             break

-        group['step'] = step
+        group['step'] = state['step'] = step = step + 1

         if group['warmup_steps'] and step < group['warmup_steps']:
             group['lr'] = group['base_lr'] * step / group['warmup_steps']
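For reference, the warmup branch above ramps the learning rate linearly from `base_lr / warmup_steps` up to `base_lr`, after which `lr` stays at `base_lr`. A minimal standalone sketch of that schedule (plain Python, not heavyball code; `warmup_lr` is a hypothetical helper name):

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Linear warmup, mirroring: group['lr'] = group['base_lr'] * step / group['warmup_steps']
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr

print([round(warmup_lr(1e-3, s, 4), 6) for s in range(1, 7)])
# [0.00025, 0.0005, 0.00075, 0.001, 0.001, 0.001]
```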
@@ -485,36 +488,72 @@ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd,
                         scale_by_adam.get_fn(): update_by_adam, #
                         scale_by_laprop.get_fn(): update_by_laprop, #
                         scale_by_adopt.get_fn(): update_by_adopt}
+_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
+                            update_by_psgd.get_fn(): scale_by_psgd, #
+                            update_by_adam.get_fn(): scale_by_adam, #
+                            update_by_laprop.get_fn(): scale_by_laprop, #
+                            update_by_adopt.get_fn(): scale_by_adopt}


 class BaseOpt(ChainOpt):
+    """
+    Base Optimizer
+
+    compile_step: bool = False
+        Whether to change some internals to try to make the optimizer compilable
+        This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+
+    promote: bool = False
+        Whether to promote the gradients to fp32 before applying the optimizer
+        Improves update quality for low-precision parameters, but increases costs
+        Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+
+    gradient_clipping: str_or_fn = None
+        The function to use for clipping the incoming gradients, before any other transformations.
+        This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+
+    update_clipping: str_or_fn = None
+        The function to use for clipping the outgoing updates before applying them, after all other transformations.
+        This will turn off
+        This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
+
+    """
+
     gradient_clipping: str_or_fn = None
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True

     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, compile_step: bool = use_default,
+                 palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
+        if not fns:
+            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
+
+        args, kwargs = None, None
+        fn = fns[-1]
+        if isinstance(fn, functools.partial):
+            fn, args, kwargs = fn.func, fn.args, fn.keywords
+        if isinstance(fn, FunctionTransform):
+            fn = fn.get_fn()
+
         if default(update_clipping, self.update_clipping) is None:
-            if
-                args, kwargs = None, None
-                fn = fns[-1]
-                if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fn.func, fn.args, fn.keywords
-                if isinstance(fn, FunctionTransform):
-                    fn = fn.get_fn()
+            if self.auto_fuse:
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
                         fn = functools.partial(fn, *args, **kwargs)
                     fns = tuple(fns)[:-1] + (fn,)
-
-        if
-            raise ValueError("
-
-
+        elif fn in _scale_to_update_map_inv:
+            if not self.auto_fuse:
+                raise ValueError("update_clipping is currently not compatible with update_by_* functions. "
+                                 "Manually select scale_by_* functions or set auto_fuse=True.")
+            fn = _scale_to_update_map_inv[fn]
+            if args is not None:
+                fn = functools.partial(fn, *args, **kwargs)
+            fns = tuple(fns)[:-1] + (fn,)

         self.compile_step = default(compile_step, self.compile_step)
+        self.promote = default(promote, self.promote)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
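The new `_scale_to_update_map_inv` lets `BaseOpt` swap the final transform in both directions: with no `update_clipping` and `auto_fuse=True`, a trailing `scale_by_*` step is replaced by its fused `update_by_*` counterpart; with an explicit `update_clipping`, a fused `update_by_*` last step is mapped back to `scale_by_*` (or an error is raised if `auto_fuse` is off). A minimal standalone sketch of that rule, using placeholder transforms and plain dicts instead of heavyball's `FunctionTransform`/`get_fn()` machinery (all names below are illustrative, not the library's API):

```python
import functools


def scale_by_adam(update):
    """Placeholder for a chainable transform that only scales the update."""
    return update


def update_by_adam(update):
    """Placeholder for the fused variant that also applies the update."""
    return update


_scale_to_update = {scale_by_adam: update_by_adam}
_update_to_scale = {update_by_adam: scale_by_adam}


def resolve_last_fn(fns: tuple, update_clipping=None, auto_fuse: bool = True) -> tuple:
    fn, args, kwargs = fns[-1], None, None
    if isinstance(fn, functools.partial):
        fn, args, kwargs = fn.func, fn.args, fn.keywords
    if update_clipping is None:
        if auto_fuse and fn in _scale_to_update:
            fn = _scale_to_update[fn]          # fuse: apply the update in the last step
    elif fn in _update_to_scale:
        if not auto_fuse:
            raise ValueError("update_clipping is not compatible with update_by_* functions")
        fn = _update_to_scale[fn]              # unfuse so update_clipping can run last
    if args is not None:
        fn = functools.partial(fn, *args, **kwargs)
    return fns[:-1] + (fn,)


print(resolve_last_fn((scale_by_adam,)))                               # -> (update_by_adam,)
print(resolve_last_fn((update_by_adam,), update_clipping="l2_clip_"))  # -> (scale_by_adam,)
```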
heavyball/utils.py
CHANGED
@@ -88,7 +88,7 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
     except ZeroDivisionError:
         ckp1 = 0

-    update, parameters, z = list_guard(update, parameters, z)
+    update, parameters, z, grad = list_guard(update, parameters, z, grad)
     lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, update, lr, beta1, decay, grad, caution)
     return weight_sum
@@ -912,13 +912,12 @@ def copy_stochastic_(target: Tensor, source: Tensor):
 @decorator_knowngood
 def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
                         g: List[Optional[Tensor]]):
-
-
-
-    for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
+    for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
+        u_ = promote(u_.view_as(p_))
+        p32_ = promote(p_)
         if caution:
-
-    p32_ = p32_ * (1 - decay * lr) +
+            u_ = _compilable_cautioning(promote(g_), u_)
+        p32_ = p32_ * (1 - decay * lr) + u_ * -lr
         copy_stochastic_(p_, p32_)


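The rewritten loop promotes each update and parameter to fp32 per tensor, optionally "cautions" the update against the raw gradient, applies decoupled weight decay (`p ← p·(1 − decay·lr) − lr·u`), and writes the result back via stochastic rounding. A simplified single-tensor sketch of that math; the sign-agreement mask below is a stand-in for `_compilable_cautioning` (an assumption, not heavyball's exact kernel), and `p.copy_` stands in for `copy_stochastic_`:

```python
import torch


def cautious_update_(p: torch.Tensor, u: torch.Tensor, g: torch.Tensor,
                     lr: float = 1e-3, decay: float = 1e-2, caution: bool = True) -> None:
    p32, u32 = p.float(), u.float()            # stand-in for promote(...)
    if caution:
        # Assumed cautioning rule: keep only entries whose sign agrees with the gradient.
        u32 = u32 * ((u32 * g.float()) > 0)
    p32 = p32 * (1 - decay * lr) + u32 * -lr   # decoupled weight decay, then the scaled update
    p.copy_(p32.to(p.dtype))                   # stand-in for copy_stochastic_(p_, p32_)


param = torch.randn(4, dtype=torch.bfloat16)
cautious_update_(param, torch.randn(4), torch.randn(4))
print(param)
```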
@@ -1102,65 +1101,98 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):


 @decorator_knowngood
-def _compilable_l2_clip_(x):
+def _compilable_l2_clip_(x, clip_at):
     ref = x
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
-    torch._foreach_maximum_(norm,
+    torch._foreach_maximum_(norm, clip_at)
     out = torch._foreach_div(x, norm)
     return stochastic_round_list_(ref, out)


-def
+def l2_normalization_(x, clip_at: float = 1e-8):
+    x = list_guard(x)
+    return _compilable_l2_clip_(x, clip_at)
+
+
+def l2_clip_(x, clip_at: float = 1.):
     x = list_guard(x)
-    return _compilable_l2_clip_(x)
+    return _compilable_l2_clip_(x, clip_at)


 @decorator_knowngood
-def _compilable_rmsnorm_clip_(x):
+def _compilable_rmsnorm_clip_(x, clip_at):
     x = list(map(promote, x))
     norm = torch._foreach_norm(x)
     norm = [n.div_(x_.numel() ** 0.5) for n, x_ in zip(norm, x)]
-    torch._foreach_maximum_(norm,
+    torch._foreach_maximum_(norm, clip_at)
     return torch._foreach_div(x, norm)


-def rmsnorm_clip_(x):
+def rmsnorm_clip_(x, clip_at: float = 1.0):
     x = list_guard(x)
-    return _compilable_rmsnorm_clip_(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)


-def
+def rmsnorm_normalize_(x, clip_at: float = 1e-6):
+    x = list_guard(x)
+    return _compilable_rmsnorm_clip_(x, clip_at)
+
+
+@decorator_knowngood
+def _compilable_mu_law_compress_(x, mu):
     """
-
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    """
+
+    for x_ in x:
+        xa = promote(x_.abs()) * mu
+        xa = xa.log1p()
+        xa = xa / math.log1p(mu)
+        xa = xa.copysign(x_)
+        copy_stochastic_(x_, xa)

+
+def mu_law_compress(x, mu=127.0):
+    """
     μ-law compression
     Args:
         x: Input tensor
         mu: Compression parameter (default 127.0 for behavior similar to trust_region=1.5)
     """
-
-
-
-
-    return [xa_.copysign_(x_) for x_, xa_ in zip(x, xa)]
+    x = list_guard(x)
+    mu = scalar_guard(mu, x[0])
+    _compilable_mu_law_compress(x, mu)
+    return x


-
+@decorator_knowngood
+def _compilable_a_law_compress_(x, A):
     """
-
+    original at https://github.com/opooladz/modded-nanogpt-psgd/blob/dc7c78082ac15fbf326f1bacd9e0ead0a2b45908/kron_mu.py
+    """
+    for x_ in x:
+        xa = promote(x_.abs()) * A
+        xa = torch.where(xa < 1, xa, 1 + xa.log())
+        xa = xa.copysign(x_)
+        xa = xa * (1 / (1 + math.log(A)))
+        copy_stochastic_(x_, xa)
+

+def a_law_compress(x, A=87.6):
+    """
     A-law compression
     Args:
         x: Input tensor
         A: Compression parameter (default 87.6 - European PCM standard)
+    :param x:
+    :param A:
+    :return:
     """
-
-
-
-
-    torch._foreach_mul_(xa, 1 / (1 + math.log(A)))
-    return xa
+    x = list_guard(x)
+    A = scalar_guard(A, x[0])
+    _compilable_a_law_compress(x, A)
+    return x


 def identity(x):
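The `clip_at` argument added above determines whether these helpers clip or normalize: each tensor is divided by `max(norm, clip_at)`, so `clip_at=1.0` only rescales tensors whose (RMS-)norm exceeds 1, while a tiny `clip_at` (the new `l2_normalization_` / `rmsnorm_normalize_` defaults) always rescales to unit norm. A small single-tensor sketch of that behaviour (not the foreach implementation; `l2_clip_sketch` is an illustrative name):

```python
import torch


def l2_clip_sketch(x: torch.Tensor, clip_at: float) -> torch.Tensor:
    norm = x.norm().clamp(min=clip_at)   # mirrors torch._foreach_maximum_(norm, clip_at)
    return x / norm


small = torch.tensor([0.3, 0.4])     # L2 norm 0.5
large = torch.tensor([3.0, 4.0])     # L2 norm 5.0
print(l2_clip_sketch(small, 1.0))    # clip_at=1.0: unchanged, only large norms are clipped
print(l2_clip_sketch(large, 1.0))    # rescaled to unit norm
print(l2_clip_sketch(small, 1e-8))   # tiny clip_at: always normalized to unit norm
```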
@@ -1168,24 +1200,24 @@ def identity(x):


 @decorator_knowngood
-def _compilable_trust_region_clip_(grad, lerp
+def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
-
-
-
-
-
-
-
-
-
-    return [stochastic_round_(grad, g32) for grad, g32 in zip(grad, g32)]
+    for x_ in grad:
+        x = promote(x_)
+        x = x / scale
+        tanh = x.tanh()
+        x = x.abs().log1p()
+        x = x.copysign(tanh) * (1 - lerp) + tanh * lerp
+        x = x * scale
+        x = x.clamp(min=-2, max=2)
+        copy_stochastic_(x_, x)


 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
     lerp, scale = scalar_guard(lerp, scale, grad[0])
-
+    _compilable_trust_region_clip_(grad, lerp, scale)
+    return grad


 @decorator
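The rewritten trust-region clip squashes each value through a blend of `log1p(|x|)` (carrying the sign of `tanh(x)`) and `tanh(x)`, then rescales and clamps to `[-2, 2]`. A hedged single-tensor sketch of the same math (plain tensor ops instead of the in-place `promote`/`copy_stochastic_` loop; `trust_region_sketch` is an illustrative name):

```python
import torch


def trust_region_sketch(x: torch.Tensor, lerp: float = 0.9, scale: float = 1.5) -> torch.Tensor:
    xs = x / scale
    tanh = xs.tanh()
    soft = xs.abs().log1p().copysign(tanh)        # sign-preserving log growth for large values
    out = (soft * (1 - lerp) + tanh * lerp) * scale
    return out.clamp(min=-2, max=2)


g = torch.tensor([-10.0, -1.0, 0.1, 1.0, 10.0])
print(trust_region_sketch(g))   # large gradients are squashed, small ones pass nearly unchanged
```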
heavyball-1.3.1.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=OK9fLde8LsrbjeL75amLXvCNwECVGVSDlHCcaNJEvyk,23104
+heavyball/utils.py,sha256=ruiOh6AQvSxMpfWO97sgRVK1NYeqKHtg2U8op1kgOrY,48410
+heavyball-1.3.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.3.1.dist-info/METADATA,sha256=EAexar-sE-vkzM0dQu6yrm-f7KQITROO0-B72mPkJIA,12022
+heavyball-1.3.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.3.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.3.1.dist-info/RECORD,,
heavyball-1.2.3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=u9w2z_aSslcokWVCiiXQJ8GSPlOhgrOFUYAwt2JfTzI,21100
-heavyball/utils.py,sha256=I2zfiB_-EP35LYr-vLyxPNl8_uJo2se3Id0IWjZeVjg,47951
-heavyball-1.2.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.2.3.dist-info/METADATA,sha256=EtW_3QIUKrKpyYUfXmGQm3_EpZkr8oQyow7gAyC4Ges,12022
-heavyball-1.2.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.2.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.2.3.dist-info/RECORD,,
{heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/LICENSE
File without changes

{heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/WHEEL
File without changes

{heavyball-1.2.3.dist-info → heavyball-1.3.1.dist-info}/top_level.txt
File without changes