heavyball 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,18 +1,12 @@
- import copy
  import functools
  import gc
- import inspect
  import math
  import random
  import string
- import sys
- import time
  import warnings
- from datetime import datetime
  from typing import List, Optional, Tuple, Callable, Union
  from unittest.mock import patch

- import hyperopt
  import numpy as np
  import torch
  from torch import Tensor
@@ -165,14 +159,17 @@ def beta_debias(beta, step):
      return 1 - (1 - beta) / (1 - beta ** step)


+ def eps_sqrt(item, eps):
+     return item.sqrt().clamp(min=eps)
+
+
  @decorator_knowngood
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                              out: List[Optional[Tensor]]):
-     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-     s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
-     copy_stochastic_list_(state, s32)
+     g32 = promote(grad)
+     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)

-     denom = [d.sqrt().clamp(min=eps) for d in s32]
+     denom = [eps_sqrt(d, eps) for d in s32]

      if out[0] is None:
          return denom
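The new `eps_sqrt` helper factors out the `sqrt().clamp(min=eps)` pattern that 1.6.0 wrote inline, and `_compilable_exp_avg_sq_` now delegates the EMA update to `_lerp` with `torch._foreach_mul`. A minimal sketch of the resulting denominator math, using plain tensors rather than heavyball's stochastic-rounding helpers:

```python
import torch

def eps_sqrt(item: torch.Tensor, eps: float) -> torch.Tensor:
    # Clamp *after* the square root, so eps bounds the denominator itself,
    # not the variance estimate under the root.
    return item.sqrt().clamp(min=eps)

s = torch.zeros(4)            # running EMA of squared gradients
g = torch.randn(4)
beta2 = 0.999

s = s.lerp(g * g, 1 - beta2)  # the update _lerp performs, minus the low-precision round-trip
denom = eps_sqrt(s, 1e-8)     # lower-bounded Adam-style denominator
update = g / denom
```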
@@ -189,7 +186,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-     g32 = list(map(promote, grad))
+     g32 = promote(grad)
      denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
      out = torch._foreach_div(g32, denom)
      copy_stochastic_list_(grad, out)
@@ -265,8 +262,8 @@ def set_torch(benchmark_limit: int = 32):
      cudnn.benchmark_limit = benchmark_limit
      torch.use_deterministic_algorithms(False)
      torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
-     opt_einsum.enabled = True
-     opt_einsum.strategy = "dp"
+     opt_einsum.enabled = False
+     opt_einsum.strategy = "auto"

      # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
      _ignore_warning(
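`set_torch` now leaves the opt_einsum backend off and resets the contraction strategy to PyTorch's default. Assuming `opt_einsum` here is `torch.backends.opt_einsum`, the flags behave as in this sketch:

```python
from torch.backends import opt_einsum

# When disabled, torch.einsum contracts operands left to right instead of
# searching for a cheaper contraction order; "auto" is the default strategy
# and is only consulted while the backend is enabled.
opt_einsum.enabled = False
opt_einsum.strategy = "auto"
```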
@@ -379,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
      target[:] = source.contiguous()[index].reshape_as(target)


- @decorator_knowngood
+ #@decorator_knowngood
  def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
      """
      Computes the eigenbases of the preconditioner using one round of power iteration
@@ -398,7 +395,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
      new_qs = []

      for m, q in zip(GG, Q):
-         if len(m) == 0:
+         if m is None:
+             new_qs.append(None)
              continue

          m = promote(m.data)
@@ -420,19 +418,20 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
      in_str = einsum_base[:exp_avg.dim()]
      out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]

-     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
+     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
      if not from_shampoo:
          return

-     to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
+     to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
      out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])

      subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
-     exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
+     exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None])
      copy_stochastic_(exp_avg, exp_avg_new)

      for q, q_new in zip(Q, new_qs):
-         copy_stochastic_(q, q_new)
+         if q is not None:
+             copy_stochastic_(q, q_new)


  def get_orthogonal_matrix(mat):
@@ -442,8 +441,8 @@ def get_orthogonal_matrix(mat):

      final = []
      for m in mat:
-         if len(m) == 0:
-             final.append([])
+         if m is None:
+             final.append(None)
              continue

          m = promote(m.data)
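From this hunk onward, the sentinel for "no preconditioner along this dimension" changes from an empty list to `None`, and every consumer filters with `is not None` instead of `len(m) > 0`. An illustrative sketch of the convention (hypothetical shapes, not library code):

```python
from typing import List, Optional
import torch

# One optional Kronecker factor per tensor dimension; None marks a skipped dim.
GG: List[Optional[torch.Tensor]] = [torch.eye(3), None, torch.eye(5)]

# 1.6.0 stored [] for skipped dims and filtered with len(m) > 0;
# 1.6.1 stores None, which the einsum consumers drop explicitly:
active = [m for m in GG if m is not None]
```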
@@ -476,7 +475,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
      for x_, y_ in zip(x, y):
          x32 = promote(x_)
          y32 = promote(y_)
-         copy_stochastic_(x_, x32.lerp(y32, a))
+         if x32.dtype != y32.dtype:
+             y32 = y32.to(x32.dtype)
+         copy_stochastic_(x_, x32 * (1 - a) + y32 * a)


  def get_beta1(group):
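`_compilable_stochastic_lerp_` now spells the interpolation out as `x * (1 - a) + y * a` and casts `y` when dtypes diverge; `Tensor.lerp` requires both operands to share a dtype. A minimal sketch with mixed-precision inputs:

```python
import torch

x32 = torch.randn(4, dtype=torch.float32)
y64 = torch.randn(4, dtype=torch.float64)
a = 0.1

# x32.lerp(y64, a) raises a dtype-mismatch error; the explicit form works
# once y is cast to x's dtype.
if x32.dtype != y64.dtype:
    y64 = y64.to(x32.dtype)
mixed = x32 * (1 - a) + y64 * a  # equivalent to torch.lerp once dtypes agree
```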
@@ -575,7 +576,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
          g0 = einsum_base[:grad.dim()]
          g1 = g0.replace(b, b.upper())
          outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-         m.lerp_(outer_product, 1 - beta)
+         stochastic_lerp_(m, outer_product, 1 - beta)


  def tree_apply(fn):
@@ -618,7 +619,8 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
      state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
      if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
          for sh in grad.shape:
-             if sh > max_precond_dim:
+             if sh > max_precond_dim or sh == 1:
+                 # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
                  state['GG'].append(None)
              else:
                  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
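Per the linked commit by @francois-rozet, size-1 dimensions now also get `None` instead of a 1×1 preconditioner, whose trivial eigenbasis carries no useful rotation. A sketch of which dimensions receive a factor under the new rule, assuming `max_precond_dim = 4`:

```python
import torch

shape, max_precond_dim = (1, 3, 8), 4
GG = [None if sh > max_precond_dim or sh == 1 else torch.zeros(sh, sh)
      for sh in shape]
# GG == [None, zeros(3, 3), None]: the singleton and oversized dims are skipped.
```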
@@ -638,10 +640,10 @@ def project(grad, Q, back: bool):
      :return:
      """
      param = einsum_base[:grad.dim()]
-     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
+     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
      if preconditioners:
          out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
-         out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+         out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
          grad = out.to(grad.dtype)
      return grad

@@ -876,7 +878,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
      ea32 = list(map(promote, state))
      grad = list(map(promote, grad))
      beta = promote(beta)
-     ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
+     stochastic_lerp_(ea32, grad, 1 - beta)
      copy_stochastic_list_(state, ea32)
      return ea32

@@ -890,7 +892,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
      g32 = list(map(promote, grad))
      exp_avg32 = _lerp(exp_avg, g32, beta1)
      denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
-     u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
+     u32 = torch._foreach_div(exp_avg32, denom)
      copy_stochastic_list_(grad, u32)


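Several per-element list comprehensions become `torch._foreach_*` calls in this release; the foreach variants dispatch one fused kernel per device/dtype group instead of one op per tensor. An equivalence sketch for the `_compilable_adam_` change above:

```python
import torch

exp_avg32 = [torch.randn(3), torch.randn(5)]
denom = [torch.rand(3) + 1e-8, torch.rand(5) + 1e-8]

u32_old = [ea / d for ea, d in zip(exp_avg32, denom)]  # 1.6.0 style
u32_new = torch._foreach_div(exp_avg32, denom)         # 1.6.1 style
assert all(torch.allclose(a, b) for a, b in zip(u32_old, u32_new))
```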
@@ -973,14 +975,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
      _compilable_update_(y, u32, decay, lr, caution, g32)

      beta1 = beta_debias(beta1, step)
-     denom = torch._foreach_sqrt(exp_avg_sq32)
-     denom = [d.clamp(min=eps) for d in denom]
-     exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-     copy_stochastic_list_(exp_avg, exp_avg32)
+     denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+     stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

      beta2 = beta_debias(beta2, step + 1)
-     exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+     stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)


  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
@@ -990,27 +989,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e


  @decorator_knowngood
- def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
      g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
      update = [e.clone() for e in exp_avg]

      beta1 = beta_debias(beta1, step)
-     denom = torch._foreach_sqrt(exp_avg_sq32)
-     denom = [d.clamp(min=1e-8) for d in denom]
-     exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
-     copy_stochastic_list_(exp_avg, exp_avg32)
+     denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
+     stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)

-     beta2 = beta_debias(beta2, step + 1)
-     exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+     stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)

      copy_stochastic_list_(grad, update)


- def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
+ def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
      exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
-     beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
-     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
      return grad


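`adopt` now exposes `eps` (defaulting to `1e-8`, the value 1.6.0 hard-coded in the clamp) and threads it through `scalar_guard` into the compiled kernel. A hedged usage sketch of the new signature; the tensor shapes are illustrative only:

```python
import torch
from heavyball.utils import adopt

grad = [torch.randn(10)]
exp_avg = [torch.zeros(10)]
exp_avg_sq = [torch.ones(10)]

# eps is now tunable per call instead of being fixed at 1e-8 inside the kernel.
new_update = adopt(grad, exp_avg_sq, exp_avg, beta1=0.9, beta2=0.999, step=1, eps=1e-8)
```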
heavyball-1.6.0.dist-info/METADATA → heavyball-1.6.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.6.0
+ Version: 1.6.1
  Summary: Efficient optimizers
  Home-page: https://github.com/HomebrewML/HeavyBall
  Author: HeavyBall Authors
heavyball-1.6.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+ heavyball/utils.py,sha256=iQxSQjw_sgJp4AvX71VdTJxJ_20Tdu7W2tdrYu5q2EI,55808
+ heavyball-1.6.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.6.1.dist-info/METADATA,sha256=yFMCDJPpD5jVOFtL4l_pM3jTw3_ZizeTSQ_ugVHIWKM,43441
+ heavyball-1.6.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.6.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.6.1.dist-info/RECORD,,
heavyball-1.6.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
- heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
- heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
- heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
- heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.6.0.dist-info/RECORD,,