heavyball-1.5.1-py3-none-any.whl → heavyball-1.5.2-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes and reflects the changes between the two versions as they appear in their public registry.
- heavyball/__init__.py +1 -36
- heavyball/chainable.py +10 -5
- heavyball/utils.py +11 -6
- {heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/METADATA +1 -1
- heavyball-1.5.2.dist-info/RECORD +8 -0
- heavyball-1.5.1.dist-info/RECORD +0 -8
- {heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/LICENSE +0 -0
- {heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/WHEEL +0 -0
- {heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -163,41 +163,6 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
-class ForeachAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
-
-
-class OrthoAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_adam)
-
-
-class AdamWOrtho(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
-                         C.orthogonalize_grad_to_param)
-
-
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -216,7 +181,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
heavyball/chainable.py
CHANGED
@@ -364,10 +364,13 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache
 
 
-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    else:
+        out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out
 
 
 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
@@ -387,7 +389,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +397,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
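
Both call sites above now forward grad into _cached_psgd_precond_grad so that cautioning happens inside the preconditioning step. The two transforms differ only in ordering: scale_by_psgd refreshes the preconditioner and then applies it, while scale_by_delayed_psgd applies the stale preconditioner first and refreshes it afterwards. A toy sketch of that ordering with scalar stand-ins (update_precond and precond are hypothetical simplifications, not heavyball's Kronecker-factored math):

# Toy scalar model of the two orderings; all names are illustrative.
def update_precond(q: float, update: float) -> float:
    return 0.9 * q + 0.1 * abs(update)  # pretend to refine the preconditioner

def precond(q: float, update: float) -> float:
    return update / (q + 1e-8)          # pretend to apply it

def psgd_step(q: float, update: float):          # scale_by_psgd
    q = update_precond(q, update)                # refresh Q first...
    return precond(q, update), q                 # ...then use it

def delayed_psgd_step(q: float, update: float):  # scale_by_delayed_psgd
    out = precond(q, update)                     # use the stale Q first...
    return out, update_precond(q, update)        # ...then refresh it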
@@ -467,6 +469,8 @@ class ChainOpt(utils.StatefulOptimizer):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']
 
+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
 
         if not vals:
@@ -492,6 +496,7 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             chain(self.state_, group, g, p, *self.fns)
 
+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
 
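
The snapshot/restore pair is needed because _cached_psgd_precond_grad now consumes the flag: it sets group['caution'] = False after cautioning once, so without the restore the flag would stay off for every later step. A minimal sketch of this one-shot-flag pattern (the dict and function are illustrative, not heavyball's API):

# Illustrative one-shot flag: the transform consumes it, the loop restores it.
def transform(group: dict, update: float) -> float:
    if group['caution']:
        update = abs(update)   # stand-in for the cautioning operation
    group['caution'] = False   # consumed: downstream must not caution again
    return update

group = {'caution': True}
caution = group['caution']     # snapshot before iterating over parameters
for update in (1.0, -2.0):
    transform(group, update)
group['caution'] = caution     # restore so the next optimizer step cautions
assert group['caution'] is True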
heavyball/utils.py
CHANGED
@@ -1300,7 +1300,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1312,8 +1315,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =
 
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1322,7 +1325,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca
 
 
 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1332,8 +1337,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
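
With these signatures, both kernels caution ea before the einsum is applied rather than letting update_param_ caution the already-preconditioned output, which is why the fused wrappers now pass caution=False downstream. _compilable_cautioning itself is not part of this diff; the sketch below shows the standard sign-agreement masking that cautious updates use, as a rough guide only, not necessarily heavyball's exact kernel:

import torch
from torch import Tensor

def caution(grad: Tensor, update: Tensor) -> Tensor:
    # Zero update entries whose sign disagrees with the gradient, then
    # rescale the survivors to preserve the mean update magnitude.
    mask = (grad * update) > 0
    scale = mask.numel() / mask.sum().clamp(min=1)
    return update * mask * scale

out = caution(torch.tensor([1.0, -1.0]), torch.tensor([0.5, 0.5]))
print(out)  # tensor([1., 0.]): disagreeing entry dropped, the rest rescaled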
heavyball-1.5.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
+heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
+heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
+heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
+heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.2.dist-info/RECORD,,
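
Each RECORD line has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64 SHA-256 of the file with the trailing '=' padding stripped (PEP 376/427). A short sketch that recomputes one entry, assuming the wheel has been unpacked into the current directory:

import base64
import hashlib
import os

path = "heavyball/__init__.py"  # any file listed in RECORD
digest = hashlib.sha256(open(path, "rb").read()).digest()
b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
print(f"{path},sha256={b64},{os.path.getsize(path)}")  # should match RECORD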
heavyball-1.5.1.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
-heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
-heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
-heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
-heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.1.dist-info/RECORD,,
{heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/LICENSE
File without changes
{heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/WHEEL
File without changes
{heavyball-1.5.1.dist-info → heavyball-1.5.2.dist-info}/top_level.txt
File without changes