heavyball 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +1 -13
- heavyball/chainable.py +10 -5
- heavyball/utils.py +17 -11
- {heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/METADATA +1 -1
- heavyball-1.5.2.dist-info/RECORD +8 -0
- heavyball-1.5.0.dist-info/RECORD +0 -8
- {heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/LICENSE +0 -0
- {heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/WHEEL +0 -0
- {heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
```diff
@@ -163,18 +163,6 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
-class OrthoAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_adam)
-
-
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -193,7 +181,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
```
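The only API-visible change here is the removal of the `OrthoAdamW` optimizer; `heavyball.OrthoAdamW` no longer resolves in 1.5.2. Below is a minimal, hedged compatibility sketch — the fallback to `OrthoLaProp` is an illustrative assumption, not an official migration path.

```python
import torch
import heavyball

model = torch.nn.Linear(4, 4)

# OrthoAdamW exists in heavyball 1.5.0 but is gone in 1.5.2; look it up defensively and
# fall back to another orthogonalizing optimizer (this fallback choice is an assumption).
OptCls = getattr(heavyball, "OrthoAdamW", None) or heavyball.OrthoLaProp
opt = OptCls(model.parameters(), lr=0.0025)
```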
heavyball/chainable.py
CHANGED
```diff
@@ -364,10 +364,12 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache
 
 
-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out
 
 
 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
@@ -387,7 +389,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +397,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -467,6 +469,8 @@ class ChainOpt(utils.StatefulOptimizer):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']
 
+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
 
         if not vals:
@@ -492,6 +496,7 @@ class ChainOpt(utils.StatefulOptimizer):
             else:
                 chain(self.state_, group, g, p, *self.fns)
 
+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
 
```
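These changes thread the raw gradient into `_cached_psgd_precond_grad` so that cautioning can be applied to the preconditioned update itself; `group['caution']` is then cleared so the mask is not applied a second time in `update_param_`, and `ChainOpt` snapshots and restores the flag each step. For readers unfamiliar with the idea, the sketch below shows the sign-agreement masking behind the `caution` flag; it is an illustrative stand-in (the function name and rescaling convention are assumptions), not heavyball's `_compilable_cautioning`.

```python
import torch

def caution(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    """Zero update entries whose sign disagrees with the gradient, then rescale.

    Illustrative only; the rescaling rule follows the common cautious-optimizer
    recipe and is an assumption about heavyball's exact behaviour.
    """
    mask = (grad * update) > 0                      # entries where update and grad agree in sign
    scale = mask.numel() / mask.sum().clamp(min=1)  # keep the mean update magnitude roughly constant
    return update * mask * scale

g = torch.tensor([1.0, -2.0, 0.5, -0.1])
u = torch.tensor([0.3, 0.4, -0.2, -0.5])
print(caution(g, u))  # keeps entries 0 and 3, zeroes the rest, rescales by 2 -> [0.6, 0.0, 0.0, -1.0]
```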
heavyball/utils.py
CHANGED
```diff
@@ -770,22 +770,23 @@ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
 
 @decorator_knowngood
 def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
-                      step: Tensor):
+                      step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
     g32 = list(map(promote, grad))
 
     exp_avg32 = _lerp32(exp_avg, g32, beta1)
-    denom = exp_avg_sq_(exp_avg_sq, g32, beta2,
+    denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
     u32 = torch._foreach_div(exp_avg32, denom)
     copy_stochastic_list_(grad, u32)
 
 
-def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int
+def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+          eps: float):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
@@ -1299,7 +1300,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor,
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1311,8 +1315,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1321,7 +1325,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca
 
 
 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1331,8 +1337,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
```
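The `utils.py` side of the change adds an explicit `eps` argument to `adam_`/`_compilable_adam_` (instead of fixing the denominator epsilon inside the helper) and lets `precond_grad_cached_`/`psgd_precond_grad` apply the caution mask via `caution=`/`grad=` before the fused parameter update. A hedged usage sketch of the new `adam_` signature follows; the tensors are illustrative.

```python
import torch
from heavyball import utils

p = torch.nn.Parameter(torch.randn(16))
exp_avg = [torch.zeros_like(p)]
exp_avg_sq = [torch.zeros_like(p)]
grad = [torch.randn_like(p)]

# adam_ mutates `grad` in place (it becomes the Adam update direction) and returns it;
# from 1.5.2 the epsilon is passed explicitly rather than being fixed inside the helper.
update = utils.adam_(exp_avg, exp_avg_sq, grad, beta1=0.9, beta2=0.99, step=1, eps=1e-8)
print(update[0].shape)
```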
heavyball-1.5.2.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
+heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
+heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
+heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
+heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.2.dist-info/RECORD,,
```
heavyball-1.5.0.dist-info/RECORD
DELETED
```diff
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=AL3oSbNB1HQ0cwEG6aPZGVMbpCXXCYOxREX7JwK4Byc,12773
-heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
-heavyball/utils.py,sha256=NFvQcQemNOugH1vAi_UH3jnnttPSgVopmS1q6jbhxkQ,50289
-heavyball-1.5.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.0.dist-info/METADATA,sha256=fUOCJvDcBQ5280TCLhUCuIRwNVMvp3ysp4qrDuJCUeI,43584
-heavyball-1.5.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.0.dist-info/RECORD,,
```
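The RECORD diffs are routine wheel bookkeeping: each shipped file is listed with a urlsafe-base64 SHA-256 digest (trailing `=` padding stripped, per the wheel spec) and its size in bytes, so the new hashes simply track the source changes above. A small sketch of how such an entry is computed; the path is illustrative:

```python
import base64
import hashlib

def record_entry(path: str) -> str:
    """Return a wheel RECORD line like 'path,sha256=<digest>,<size>' for a file."""
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

print(record_entry("heavyball/__init__.py"))
```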
{heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/LICENSE
File without changes

{heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/WHEEL
File without changes

{heavyball-1.5.0.dist-info → heavyball-1.5.2.dist-info}/top_level.txt
File without changes