heavyball 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
heavyball/chainable.py CHANGED
@@ -1,6 +1,5 @@
 import functools
 import random
-import warnings
 from typing import Optional, Union, Literal
 
 import torch
@@ -85,10 +84,11 @@ class CopyGuard(FunctionTransform):
 
 
 class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
-    def __init__(self, fn, names, init_fn):
+    def __init__(self, fn, names, init_fn, skip_first: bool = True):
         super().__init__(fn)
         self.names = names
         self.init_fn = init_fn
+        self.skip_first = skip_first
 
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         vars = []
@@ -97,7 +97,7 @@ class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the ge
             st = state(p)
             skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
             vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
-        if skip_update:
+        if skip_update and self.skip_first:
             raise SkipUpdate
         return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
@@ -109,8 +109,17 @@ class NoState(FunctionTransform):
 
 class NoStateNoForeach(FunctionTransform):
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        updates = []
+        skip_update = False
         for a in zip(update, grad, param, *args):
-            return self.fn(group, *a, **kwargs)
+            try:
+                updates.append(self.fn(group, *a, **kwargs))
+            except SkipUpdate:
+                skip_update = True
+                pass
+        if skip_update:
+            raise SkipUpdate
+        return updates
 
 
 def zero_guard(*names):
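
Note: in 1.3.1, NoStateNoForeach returned inside the loop, so only the first tensor of each group was ever transformed. The 1.4.1 version collects one result per tensor and re-raises SkipUpdate after the loop if any call skipped. A minimal standalone sketch of the same pattern (hypothetical helper name, assuming heavyball's SkipUpdate exception):

    def apply_per_tensor(fn, group, updates, grads, params):
        results, skipped = [], False
        for u, g, p in zip(updates, grads, params):
            try:
                results.append(fn(group, u, g, p))
            except SkipUpdate:
                skipped = True  # fn handled (or skipped) this tensor itself
        if skipped:
            raise SkipUpdate  # propagate so later transforms do not re-apply
        return results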
@@ -118,11 +127,11 @@ def zero_guard(*names):
 
 
 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names)
+    return functools.partial(CopyGuard, index=index, names=names,)
 
 
-def general_guard(*names, init_fn):
-    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
+def general_guard(*names, init_fn, skip_first: bool = True):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
 
 
 def no_state(fn):
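
The new skip_first flag controls what happens on the step that first initializes a transform's state: with the default skip_first=True, GeneralGuard raises SkipUpdate so the chain skips that step (the 1.3.1 behavior); with skip_first=False the wrapped function runs immediately after init_fn. The PSGD transforms below pass skip_first=False. Roughly, with names simplified from this diff:

    # tail of GeneralGuard.__call__, sketched
    if state_was_just_initialized and self.skip_first:
        raise SkipUpdate        # first step is a no-op
    return self.fn(...)         # skip_first=False: run on the init step too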
@@ -311,18 +320,18 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
-        return
+        return Q_mat
 
-    Q = [utils.promote(q_) for q_ in Q]
     utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
 
-    if grad.dim() > 1 and precond_schedule(group, balance_probability, "balance_prob"):
+    if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
         if group['store_triu_as_line']:
             utils.psgd_balance_Q([q_ for _, q_ in Q])
         else:
             utils.psgd_balance_Q(Q)
 
-    _update_psgd_cache(cached, Q_cache, Q_mat)
+    return _update_psgd_cache(cached, Q_cache, Q_mat)
+
 
 def _update_psgd_cache(cached, Q_cache, q):
     if not cached:
@@ -351,44 +360,47 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
                                     group['caution'], *Q_mat)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
-    old = update
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     return precond
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                    prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
+    Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
+                                 update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
     raise SkipUpdate
 
 
-@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
+@general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                            prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+                         Q_mat, Q, exprs, prob)
     raise SkipUpdate
 
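
Two behavioral changes run through all four PSGD transforms above: the preconditioner update now consumes the momentum-carrying update buffer instead of the raw gradient when group['momentum_into_precond_update'] is set, and the non-delayed variants rebind Q_mat to the return value so a freshly refreshed cache is used within the same step. The shared selection logic reduces to:

    precond_input = update if group['momentum_into_precond_update'] else grad
    Q_mat = _update_psgd_precond(cached, Q_cache, group, param, precond_input,
                                 Q_mat, Q, exprs, prob)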
@@ -422,7 +434,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 
 class ChainOpt(utils.StatefulOptimizer):
-    compile_step: bool = False
     promote: bool = False
 
     def __init__(self, params, defaults, foreach: bool, *fns):
@@ -432,6 +443,10 @@ class ChainOpt(utils.StatefulOptimizer):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
+        if 'prev_lr' in group and group['prev_lr'] != group['lr']:
+            utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and '
+                            f'only supported with foreach=True (currently foreach={group["foreach"]}).')
+            group['base_lr'] = group['lr']
 
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
         if not vals:
@@ -451,9 +466,10 @@ class ChainOpt(utils.StatefulOptimizer):
         group['step'] = state['step'] = step = step + 1
 
         if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+            group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
+
         else:
-            group['lr'] = group['base_lr']
+            group['prev_lr'] = group['lr'] = group['base_lr']
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
@@ -461,7 +477,7 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             chain(self.state_, group, g, p, *self.fns)
 
-        group['lr'] = None
+        group['lr'] = group['prev_lr']
         group['step'] = None
 
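
The lr handling is the other notable change in ChainOpt._step: 1.3.1 overwrote group['lr'] during warmup and reset it to None afterwards, while 1.4.1 records every value it writes in group['prev_lr'] and restores it after the step. A mismatch between prev_lr and lr at the start of the next step therefore means the caller (e.g. an external scheduler) changed lr, which triggers warn_once and a base_lr refresh. A sketch of the 1.4.1 flow, with names taken from this diff:

    if 'prev_lr' in group and group['prev_lr'] != group['lr']:
        group['base_lr'] = group['lr']        # caller changed lr externally
    ...
    group['prev_lr'] = group['lr'] = scaled   # remember what we wrote
    ...                                       # run the transform chain
    group['lr'] = group['prev_lr']            # restore instead of None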
heavyball/utils.py CHANGED
@@ -193,12 +193,12 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad
 
 
-# TODO: This lerp was fucked - check other lerps
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-    s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
-    copy_stochastic_list_(state, s32)
-    copy_stochastic_list_(grad, s32)
+    for s, g in zip(state, grad):
+        lerped = s.lerp(g, 1 - beta)
+        copy_stochastic_(s, lerped)
+        copy_stochastic_(g, lerped)
 
 
 def scale_by_exp_avg_(state, grad, beta):
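
For reference, torch.lerp(s, g, w) computes s + w * (g - s), so s.lerp(g, 1 - beta) is the standard EMA update beta * s + (1 - beta) * g. The rewrite keeps each buffer in its storage dtype and round-trips every result through stochastic rounding per tensor instead of promoting whole lists. A quick check of the identity:

    import torch
    s, g, beta = torch.randn(4), torch.randn(4), 0.9
    assert torch.allclose(s.lerp(g, 1 - beta), beta * s + (1 - beta) * g)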
@@ -592,6 +592,7 @@ def project(grad, Q, back: bool):
 
 class StatefulOptimizer(torch.optim.Optimizer):
     ema_decay: float = 0.001
+    compile_step: bool = False
 
     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, 'foreach': foreach})
@@ -637,6 +638,10 @@ class StatefulOptimizer(torch.optim.Optimizer):
 
             p.grad = None
 
+            if self.compile_step:
+                yield p, grad
+                continue
+
             p_views = merge_group(group, p)
             if grad is not None:
                 grad = merge_group(group, grad)
@@ -1030,7 +1035,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q):
     V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
     eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
     eps *= G.norm() / G.numel()
-    G += V * eps
+    G = G + V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
@@ -1078,26 +1083,20 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
         term1 = promote(torch.einsum(exprG, A, A))
         term2 = promote(torch.einsum(exprG, conjB, conjB))
 
-        term2 += term1  # a + b
-        term1 *= 2  # 2a
-        if term1.dtype == term2.dtype:
-            term1 -= term2  # 2a - (a + b) == a - b
-        else:
-            term1 = term1 - term2
+        term1, term2 = term1 - term2, term1 + term2
 
         term1 *= precond_lr
        norm = term2.norm(float('inf'))
         if q.dim() < 2:
-            term1 *= q.to(term1.dtype)
-            term1 /= norm.clamp_(min=tiny_bf16)
+            term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
         else:
             torch.triu(term1, out=term1)
-            term1 /= psgd_lb(term2, norm).clamp_(tiny_bf16)
-            torch.matmul(term1, q, out=term1)
+            term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
+            term1 = torch.mm(term1, q)
         if store_triu_as_line:
             term1 = triu_to_line([term1])[0][1]
             o = o[1]
-        stochastic_add_([o], [term1], -1)
+        stochastic_add_(o, term1, -1)
 
 
 @decorator_knowngood
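
The tuple assignment replaces 1.3.1's buffer-reuse trick, which built a - b as 2a - (a + b); for a = term1, b = term2 the two are algebraically identical:

    import torch
    a, b = torch.randn(3, 3).double(), torch.randn(3, 3).double()
    new_t1, new_t2 = a - b, a + b      # 1.4.1
    old_t1 = 2 * a - (a + b)           # 1.3.1: 2a - (a + b) == a - b
    assert torch.allclose(new_t1, old_t1)

The torch.where(norm > 0, ...) guard additionally falls back to the (clamped) norm when term2 is all zeros, presumably to avoid propagating NaNs from psgd_lb in that degenerate case.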
@@ -1162,7 +1161,7 @@ def mu_law_compress(x, mu=127.0):
     """
     x = list_guard(x)
     mu = scalar_guard(mu, x[0])
-    _compilable_mu_law_compress(x, mu)
+    _compilable_mu_law_compress_(x, mu)
     return x
 
@@ -1191,7 +1190,7 @@ def a_law_compress(x, A=87.6):
     """
     x = list_guard(x)
     A = scalar_guard(A, x[0])
-    _compilable_a_law_compress(x, A)
+    _compilable_a_law_compress_(x, A)
     return x
 
@@ -1295,6 +1294,7 @@ def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, gra
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
+
     lr = scalar_guard(lr, param[0])
     _compilable_fused_precond_grad_cached_(expr, ea, param, lr, grad, decay, caution, *cached_q)
@@ -1310,7 +1310,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, grad, *preconds)
+    precond = psgd_precond_grad(expr, ea, *preconds)
     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
 
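
This small change appears to be a genuine bug fix: the fused path preconditioned the raw gradient, while the non-fused path preconditions ea (the accumulated update), so the two disagreed whenever momentum was active. 1.4.1 aligns the fused kernel with the non-fused one:

    precond = psgd_precond_grad(expr, ea, *preconds)  # was: (expr, grad, ...)
    update_param_(param, precond, lr, decay, caution=caution, grad=grad)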
heavyball-1.3.1.dist-info/METADATA → heavyball-1.4.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.3.1
+Version: 1.4.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
heavyball-1.4.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
+heavyball/utils.py,sha256=djwaSLZOB8B-xD2jJxZfXTJpJrcWp-mWTmKxC2F5Sh0,48330
+heavyball-1.4.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.4.1.dist-info/METADATA,sha256=jd7AC5ywdThr-09cCh38LZF4s-kL86Sc6PpZE-LN1iI,12022
+heavyball-1.4.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.4.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.4.1.dist-info/RECORD,,
heavyball-1.3.1.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
-heavyball/chainable.py,sha256=OK9fLde8LsrbjeL75amLXvCNwECVGVSDlHCcaNJEvyk,23104
-heavyball/utils.py,sha256=ruiOh6AQvSxMpfWO97sgRVK1NYeqKHtg2U8op1kgOrY,48410
-heavyball-1.3.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.3.1.dist-info/METADATA,sha256=EAexar-sE-vkzM0dQu6yrm-f7KQITROO0-B72mPkJIA,12022
-heavyball-1.3.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.3.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.3.1.dist-info/RECORD,,