heavyball-1.6.0-py3-none-any.whl → heavyball-1.6.2-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
heavyball/utils.py CHANGED
@@ -1,18 +1,12 @@
- import copy
  import functools
  import gc
- import inspect
  import math
  import random
  import string
- import sys
- import time
  import warnings
- from datetime import datetime
  from typing import List, Optional, Tuple, Callable, Union
  from unittest.mock import patch

- import hyperopt
  import numpy as np
  import torch
  from torch import Tensor
@@ -165,14 +159,17 @@ def beta_debias(beta, step):
      return 1 - (1 - beta) / (1 - beta ** step)


+ def eps_sqrt(item, eps):
+     return item.sqrt().clamp(min=eps)
+
+
  @decorator_knowngood
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                              out: List[Optional[Tensor]]):
-     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-     s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
-     copy_stochastic_list_(state, s32)
+     g32 = promote(grad)
+     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)

-     denom = [d.sqrt().clamp(min=eps) for d in s32]
+     denom = [eps_sqrt(d, eps) for d in s32]

      if out[0] is None:
          return denom
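
The refactor above introduces `eps_sqrt` to centralize the sqrt-then-clamp denominator and reroutes the second-moment update through `_lerp` on squared gradients. A minimal plain-tensor sketch of the arithmetic these helpers perform (illustrative names, without the stochastic rounding and foreach batching of the real code):

    import torch

    def eps_sqrt(item, eps):
        # square root of the second-moment estimate, floored at eps
        return item.sqrt().clamp(min=eps)

    g = torch.randn(4)
    s = torch.zeros(4)
    beta2, eps = 0.999, 1e-8

    # EMA of squared gradients: s <- beta2 * s + (1 - beta2) * g^2,
    # which is exactly lerp(s, g * g, 1 - beta2)
    s = s.lerp(g * g, 1 - beta2)
    denom = eps_sqrt(s, eps)   # Adam-style denominator
    update = g / denom
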
@@ -189,7 +186,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-     g32 = list(map(promote, grad))
+     g32 = promote(grad)
      denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
      out = torch._foreach_div(g32, denom)
      copy_stochastic_list_(grad, out)
@@ -265,8 +262,8 @@ def set_torch(benchmark_limit: int = 32):
      cudnn.benchmark_limit = benchmark_limit
      torch.use_deterministic_algorithms(False)
      torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
-     opt_einsum.enabled = True
-     opt_einsum.strategy = "dp"
+     opt_einsum.enabled = False
+     opt_einsum.strategy = "auto"

      # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
      _ignore_warning(
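
1.6.2 turns the einsum path search off and resets the strategy to its default. Assuming `opt_einsum` here is `torch.backends.opt_einsum` (its import sits outside this hunk), this is the knob being turned; a hedged sketch:

    import torch

    # With the backend disabled, torch.einsum skips contraction-path
    # optimization; "strategy" is only consulted while it is enabled
    # ("dp" = dynamic-programming search, "auto" = the library default).
    torch.backends.opt_einsum.enabled = False
    torch.backends.opt_einsum.strategy = "auto"

    a, b, c = (torch.randn(8, 8) for _ in range(3))
    out = torch.einsum("ij,jk,kl->il", a, b, c)
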
@@ -379,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
      target[:] = source.contiguous()[index].reshape_as(target)


- @decorator_knowngood
+ # @decorator_knowngood
  def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
      """
      Computes the eigenbases of the preconditioner using one round of power iteration
@@ -398,7 +395,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
      new_qs = []

      for m, q in zip(GG, Q):
-         if len(m) == 0:
+         if m is None:
+             new_qs.append(None)
              continue

          m = promote(m.data)
@@ -420,52 +418,60 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
      in_str = einsum_base[:exp_avg.dim()]
      out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]

-     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
+     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
      if not from_shampoo:
          return

-     to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
+     to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
      out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])

      subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
-     exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
+     exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
+                                *[q for q in new_qs if q is not None])
      copy_stochastic_(exp_avg, exp_avg_new)

      for q, q_new in zip(Q, new_qs):
-         copy_stochastic_(q, q_new)
+         if q is not None:
+             copy_stochastic_(q, q_new)


- def get_orthogonal_matrix(mat):
+ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
      """
      Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
      """

      final = []
      for m in mat:
-         if len(m) == 0:
-             final.append([])
+         if m is None:
+             final.append(None)
              continue

          m = promote(m.data)

          device, dtype = m.device, m.dtype
-         for modifier in (None, torch.double, 'cpu'):
-             if modifier is not None:
-                 m = m.to(modifier)
+         eps = min_eps
+         while True:
              try:
-                 eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
+                 eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
+                 eigval, eigvec = torch.linalg.eigh(m + eps * eye)
                  eigvec = eigvec.to(device=device, dtype=dtype)
                  break
              except torch.OutOfMemoryError:
-                 pass
+                 if m.device.type == 'cpu':
+                     raise
+                 else:
+                     m = m.cpu()
              except RuntimeError:  # failed to compute eigenvalues
-                 continue
+                 if m.dtype != torch.double:
+                     m = m.double()
+                 elif eps < max_eps:
+                     eps = eps ** (2 / 3)
+                 else:
+                     raise
              clean()
-         else:
-             raise RuntimeError("Failed to compute eigenvalues.")

+         eigvec = eigvec.to(device=m.device, dtype=m.dtype)
          eigvec = torch.flip(eigvec, [1])
-
          final.append(eigvec)

      return final
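
The rewritten `get_orthogonal_matrix` replaces the fixed `(None, double, 'cpu')` escalation with an explicit retry loop: fall back to CPU on out-of-memory, and on numerical failure first promote to float64, then grow the diagonal jitter from `min_eps` toward `max_eps` via `eps ** (2 / 3)` (for eps below 1 this increases it: 1e-30 → 1e-20 → roughly 5e-14 → ...), raising once the cap is exceeded. A standalone sketch of the same policy (illustrative function name):

    import torch

    def robust_eigh(m, max_eps=1e-3, min_eps=1e-30):
        eps = min_eps
        while True:
            try:
                eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
                return torch.linalg.eigh(m + eps * eye)
            except torch.OutOfMemoryError:
                if m.device.type == 'cpu':
                    raise          # nowhere cheaper left to go
                m = m.cpu()        # retry the decomposition on CPU
            except RuntimeError:   # eigh failed to converge
                if m.dtype != torch.double:
                    m = m.double()         # more precision first
                elif eps < max_eps:
                    eps = eps ** (2 / 3)   # then more regularization
                else:
                    raise
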
@@ -476,7 +482,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
      for x_, y_ in zip(x, y):
          x32 = promote(x_)
          y32 = promote(y_)
-         copy_stochastic_(x_, x32.lerp(y32, a))
+         if x32.dtype != y32.dtype:
+             y32 = y32.to(x32.dtype)
+         copy_stochastic_(x_, x32 * (1 - a) + y32 * a)


  def get_beta1(group):
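
The new body is the algebraic expansion of lerp: `x.lerp(y, a) == x + a * (y - x) == x * (1 - a) + y * a`; writing it out makes room for the added dtype alignment between the promoted operands. A quick equivalence check:

    import torch

    x, y, a = torch.randn(5), torch.randn(5), 0.1
    assert torch.allclose(x.lerp(y, a), x * (1 - a) + y * a)
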
@@ -575,7 +583,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
      g0 = einsum_base[:grad.dim()]
      g1 = g0.replace(b, b.upper())
      outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-     m.lerp_(outer_product, 1 - beta)
+     stochastic_lerp_(m, outer_product, 1 - beta)


  def tree_apply(fn):
@@ -618,7 +626,8 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
      state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
      if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
          for sh in grad.shape:
-             if sh > max_precond_dim:
+             if sh > max_precond_dim or sh == 1:
+                 # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
                  state['GG'].append(None)
              else:
                  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
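
With the added `sh == 1` guard (credited above to @francois-rozet), singleton axes are now skipped like oversized ones instead of receiving a useless 1×1 preconditioner. A worked instance with hypothetical shapes:

    import torch

    max_precond_dim = 2048
    grad = torch.zeros(1, 128, 10000)

    GG = []
    for sh in grad.shape:
        if sh > max_precond_dim or sh == 1:
            GG.append(None)                  # axis left unpreconditioned
        else:
            GG.append(torch.zeros(sh, sh))
    # GG is [None, <128x128 tensor>, None]
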
@@ -638,10 +647,10 @@ def project(grad, Q, back: bool):
      :return:
      """
      param = einsum_base[:grad.dim()]
-     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
+     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
      if preconditioners:
          out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
-         out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+         out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
          grad = out.to(grad.dtype)
      return grad

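For a 2-D gradient with both axes preconditioned, `project` builds the subscript `ab,aA,bB->AB` (and `ab,Aa,Bb->AB` for `back=True`); `None` entries now drop out of both the subscript and the operand list. A worked round trip with hypothetical shapes:

    import torch

    grad = torch.randn(16, 32)
    Q = [torch.linalg.qr(torch.randn(16, 16)).Q,   # left eigenbasis
         torch.linalg.qr(torch.randn(32, 32)).Q]   # right eigenbasis

    projected = torch.einsum('ab,aA,bB->AB', grad, *Q)       # back=False
    restored = torch.einsum('ab,Aa,Bb->AB', projected, *Q)   # back=True
    assert torch.allclose(restored, grad, atol=1e-4)
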
@@ -876,7 +885,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
      ea32 = list(map(promote, state))
      grad = list(map(promote, grad))
      beta = promote(beta)
-     ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
+     stochastic_lerp_(ea32, grad, 1 - beta)
      copy_stochastic_list_(state, ea32)
      return ea32

@@ -890,7 +899,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, eps: Tensor):
      g32 = list(map(promote, grad))
      exp_avg32 = _lerp(exp_avg, g32, beta1)
      denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
-     u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
+     u32 = torch._foreach_div(exp_avg32, denom)
      copy_stochastic_list_(grad, u32)


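Taken together with `_lerp` and `_compilable_exp_avg_sq_` above, this computes the familiar Adam update in fp32 working space. A plain-tensor sketch of the same math, assuming (consistent with the rest of the file, though not shown in this hunk) that the betas pass through `beta_debias` before the moment updates:

    import torch

    def adam_update(g, exp_avg, exp_avg_sq, beta1, beta2, step, eps=1e-8):
        b1 = 1 - (1 - beta1) / (1 - beta1 ** step)   # beta_debias
        b2 = 1 - (1 - beta2) / (1 - beta2 ** step)
        exp_avg.lerp_(g, 1 - b1)                     # first moment
        exp_avg_sq.lerp_(g * g, 1 - b2)              # second moment
        return exp_avg / exp_avg_sq.sqrt().clamp(min=eps)  # the u32 above
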
@@ -973,14 +982,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
      _compilable_update_(y, u32, decay, lr, caution, g32)

      beta1 = beta_debias(beta1, step)
-     denom = torch._foreach_sqrt(exp_avg_sq32)
-     denom = [d.clamp(min=eps) for d in denom]
-     exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-     copy_stochastic_list_(exp_avg, exp_avg32)
+     denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+     stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

      beta2 = beta_debias(beta2, step + 1)
-     exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+     stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)


  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
@@ -990,27 +996,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):


  @decorator_knowngood
- def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
      g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
      update = [e.clone() for e in exp_avg]

      beta1 = beta_debias(beta1, step)
-     denom = torch._foreach_sqrt(exp_avg_sq32)
-     denom = [d.clamp(min=1e-8) for d in denom]
-     exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-     copy_stochastic_list_(exp_avg, exp_avg32)
+     denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+     stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

-     beta2 = beta_debias(beta2, step + 1)
-     exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+     stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)

      copy_stochastic_list_(grad, update)


- def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
      exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-     beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
      return grad

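The ordering is what makes this ADOPT rather than Adam: the returned `update` is the previous `exp_avg`, cloned before any state changes, and only afterwards are the first moment (from the eps-floored, normalized gradient) and the second moment refreshed. A compact single-tensor sketch of that control flow (illustrative, without the stochastic rounding and foreach batching):

    import torch

    def adopt_step(g, exp_avg, exp_avg_sq, beta1, beta2, step, eps=1e-8):
        update = exp_avg.clone()             # use last step's momentum

        b1 = 1 - (1 - beta1) / (1 - beta1 ** step)   # beta_debias
        denom = exp_avg_sq.sqrt().clamp(min=eps)     # eps_sqrt
        exp_avg.lerp_(g / denom, 1 - b1)     # momentum of normalized grad

        exp_avg_sq.lerp_(g * g, 1 - beta2)   # second moment, updated last
        return update
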
heavyball-1.6.0.dist-info/METADATA → heavyball-1.6.2.dist-info/METADATA RENAMED
@@ -1,26 +1,26 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: heavyball
- Version: 1.6.0
- Summary: Efficient optimizers
- Home-page: https://github.com/HomebrewML/HeavyBall
- Author: HeavyBall Authors
- Author-email: github.heavyball@nestler.sh
- License: BSD
- Classifier: Development Status :: 5 - Production/Stable
- Classifier: License :: OSI Approved :: BSD License
- Classifier: Programming Language :: Python
- Classifier: Programming Language :: Python :: 3.7
- Classifier: Programming Language :: Python :: 3.8
- Classifier: Programming Language :: Python :: 3.9
- Classifier: Topic :: Software Development :: Libraries
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Version: 1.6.2
+ Summary: Efficient Optimizers
+ Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
+ Project-URL: source, https://github.com/HomebrewML/HeavyBall
+ Project-URL: tracker, https://github.com/HomebrewML/HeavyBall/issues
+ Keywords: torch,optimizer,muon,soap,psgd
  Classifier: Intended Audience :: Developers
- Requires-Python: >=3.7
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: BSD License
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: opt-einsum
- Requires-Dist: torch
- Requires-Dist: numpy
+ Requires-Dist: opt-einsum>=3.0.0
+ Requires-Dist: torch>=2.0.0
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"

  # `heavyball`: Efficient Optimizers

heavyball-1.6.2.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+ heavyball/utils.py,sha256=CFBFHTekWaqKhmrSLuMvRsxZ41YxPfsYihEPvJMKOQc,56088
+ heavyball-1.6.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.6.2.dist-info/METADATA,sha256=q2CEAHIg6jdGn7dey36EMExwJBrNFTDgZFpEzEHDvBY,43479
+ heavyball-1.6.2.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+ heavyball-1.6.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.6.2.dist-info/RECORD,,

heavyball-1.6.0.dist-info/WHEEL → heavyball-1.6.2.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.43.0)
+ Generator: setuptools (75.8.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

heavyball-1.6.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
- heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
- heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
- heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
- heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.6.0.dist-info/RECORD,,