heavyball-1.5.0-py3-none-any.whl → heavyball-1.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -163,18 +163,6 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)


-class OrthoAdamW(C.BaseOpt):
-    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
-                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
-                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults = locals()
-        defaults.pop("self")
-        params = defaults.pop("params")
-        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
-                         C.orthogonalize_grad_to_param, C.scale_by_adam)
-
-
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
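Note for downstream users: OrthoAdamW is removed above with no replacement elsewhere in this diff. Its body doubles as a recipe, so the same optimizer can be recomposed on the user side. A minimal sketch, assuming heavyball.chainable still exports BaseOpt, str_or_fn, use_default, orthogonalize_grad_to_param, and scale_by_adam in 1.5.2 (this diff shows them only inside the removed code):

import heavyball.chainable as C

class OrthoAdamW(C.BaseOpt):
    # Sketch reconstructed from the removed class above; only a subset of its
    # keyword arguments is kept for brevity.
    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0,
                 foreach: bool = True, gradient_clipping: C.str_or_fn = C.use_default,
                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default):
        defaults = locals()
        defaults.pop("self")             # locals() captures 'self'; it must not enter defaults
        params = defaults.pop("params")  # params are passed positionally, not as a default
        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
                         C.orthogonalize_grad_to_param,  # project the gradient first
                         C.scale_by_adam)                # then apply the Adam-style update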
@@ -193,7 +181,7 @@ class ForeachPSGDKron(C.BaseOpt):
                  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
heavyball/chainable.py CHANGED
@@ -364,10 +364,12 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache


-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out


 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
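For context, group['caution'] drives cautious masking: update entries whose sign disagrees with the raw gradient are zeroed before the parameter step. heavyball's _compilable_cautioning is not part of this diff; the following standalone sketch illustrates the idea (the exact mask and rescaling in heavyball may differ):

import torch

def caution_sketch(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    # Keep only entries where the update agrees in sign with the gradient.
    mask = (grad * update) > 0
    # Rescale survivors so the average update magnitude stays comparable.
    scale = mask.numel() / mask.sum().clamp(min=1)
    return update * mask * scale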
@@ -387,7 +389,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)


 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -395,7 +397,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -467,6 +469,8 @@ class ChainOpt(utils.StatefulOptimizer):
                                 f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']

+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))

         if not vals:
@@ -492,6 +496,7 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             chain(self.state_, group, g, p, *self.fns)

+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None

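The snapshot/restore pair above exists because _cached_psgd_precond_grad (see the chainable.py hunk above) flips group['caution'] to False once the mask has been applied, preventing a second application later in the same chain; restoring the saved value afterwards lets the next step() see the user's original setting. The pattern in isolation, as a sketch (the loop and names are stand-ins, not ChainOpt's real internals):

caution = group['caution']            # snapshot the user's setting
for p, g in params_and_grads:         # hypothetical iteration over the group
    chain(state, group, g, p, *fns)   # may set group['caution'] = False mid-chain
group['caution'] = caution            # restore it for the next optimizer step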
heavyball/utils.py CHANGED
@@ -770,22 +770,23 @@ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):

 @decorator_knowngood
 def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
-                      step: Tensor):
+                      step: Tensor, eps: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)

     g32 = list(map(promote, grad))

     exp_avg32 = _lerp32(exp_avg, g32, beta1)
-    denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
+    denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
     u32 = torch._foreach_div(exp_avg32, denom)
     copy_stochastic_list_(grad, u32)


-def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+          eps: float):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad

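With eps threaded through instead of the hardcoded 1e-8, callers now choose the denominator floor of the Adam update explicitly. A usage sketch (shapes and values are illustrative; note that adam_ writes the update back into the grad tensors and returns them):

import torch
from heavyball.utils import adam_

exp_avg = [torch.zeros(4)]
exp_avg_sq = [torch.zeros(4)]
grad = [torch.randn(4)]

# eps is now an explicit argument; 1e-8 reproduces the pre-1.5.2 behavior.
update = adam_(exp_avg, exp_avg_sq, grad, 0.9, 0.99, 1, 1e-8)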
@@ -1299,7 +1300,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random


 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
@@ -1311,8 +1315,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =

 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)


 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1321,7 +1325,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca


 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1331,8 +1337,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):

 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)


 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
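Both fused paths now apply caution while computing the preconditioned gradient and pass caution=False to update_param_, so the mask runs exactly once per step. That matters because mask-and-rescale is not idempotent: applying it twice compounds the rescale. A small demonstration, reusing the hypothetical caution_sketch from the chainable.py note above (illustrative only; heavyball's actual rescaling may differ):

import torch

g = torch.tensor([1.0, -1.0, 2.0, -2.0])
u = torch.tensor([0.5, 0.5, 0.5, 0.5])

once = caution_sketch(g, u)      # tensor([1., 0., 1., 0.]): two entries masked, survivors doubled
twice = caution_sketch(g, once)  # tensor([2., 0., 2., 0.]): survivors doubled again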
heavyball-1.5.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.0
+Version: 1.5.2
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball-1.5.2.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
+heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
+heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
+heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
+heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.5.2.dist-info/RECORD,,
heavyball-1.5.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=AL3oSbNB1HQ0cwEG6aPZGVMbpCXXCYOxREX7JwK4Byc,12773
-heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
-heavyball/utils.py,sha256=NFvQcQemNOugH1vAi_UH3jnnttPSgVopmS1q6jbhxkQ,50289
-heavyball-1.5.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.5.0.dist-info/METADATA,sha256=fUOCJvDcBQ5280TCLhUCuIRwNVMvp3ysp4qrDuJCUeI,43584
-heavyball-1.5.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.5.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.5.0.dist-info/RECORD,,