heavyball 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/utils.py +40 -45
- {heavyball-1.6.0.dist-info → heavyball-1.6.1.dist-info}/METADATA +1 -1
- heavyball-1.6.1.dist-info/RECORD +8 -0
- heavyball-1.6.0.dist-info/RECORD +0 -8
- {heavyball-1.6.0.dist-info → heavyball-1.6.1.dist-info}/LICENSE +0 -0
- {heavyball-1.6.0.dist-info → heavyball-1.6.1.dist-info}/WHEEL +0 -0
- {heavyball-1.6.0.dist-info → heavyball-1.6.1.dist-info}/top_level.txt +0 -0
heavyball/utils.py
CHANGED
@@ -1,18 +1,12 @@
|
|
1
|
-
import copy
|
2
1
|
import functools
|
3
2
|
import gc
|
4
|
-
import inspect
|
5
3
|
import math
|
6
4
|
import random
|
7
5
|
import string
|
8
|
-
import sys
|
9
|
-
import time
|
10
6
|
import warnings
|
11
|
-
from datetime import datetime
|
12
7
|
from typing import List, Optional, Tuple, Callable, Union
|
13
8
|
from unittest.mock import patch
|
14
9
|
|
15
|
-
import hyperopt
|
16
10
|
import numpy as np
|
17
11
|
import torch
|
18
12
|
from torch import Tensor
|
@@ -165,14 +159,17 @@ def beta_debias(beta, step):
|
|
165
159
|
return 1 - (1 - beta) / (1 - beta ** step)
|
166
160
|
|
167
161
|
|
162
|
+
def eps_sqrt(item, eps):
|
163
|
+
return item.sqrt().clamp(min=eps)
|
164
|
+
|
165
|
+
|
168
166
|
@decorator_knowngood
|
169
167
|
def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
|
170
168
|
out: List[Optional[Tensor]]):
|
171
|
-
|
172
|
-
s32 =
|
173
|
-
copy_stochastic_list_(state, s32)
|
169
|
+
g32 = promote(grad)
|
170
|
+
s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
|
174
171
|
|
175
|
-
denom = [d
|
172
|
+
denom = [eps_sqrt(d, eps) for d in s32]
|
176
173
|
|
177
174
|
if out[0] is None:
|
178
175
|
return denom
|
@@ -189,7 +186,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
|
|
189
186
|
|
190
187
|
@decorator_knowngood
|
191
188
|
def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
|
192
|
-
g32 =
|
189
|
+
g32 = promote(grad)
|
193
190
|
denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
|
194
191
|
out = torch._foreach_div(g32, denom)
|
195
192
|
copy_stochastic_list_(grad, out)
|
@@ -265,8 +262,8 @@ def set_torch(benchmark_limit: int = 32):
|
|
265
262
|
cudnn.benchmark_limit = benchmark_limit
|
266
263
|
torch.use_deterministic_algorithms(False)
|
267
264
|
torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
|
268
|
-
opt_einsum.enabled =
|
269
|
-
opt_einsum.strategy = "
|
265
|
+
opt_einsum.enabled = False
|
266
|
+
opt_einsum.strategy = "auto"
|
270
267
|
|
271
268
|
# Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
|
272
269
|
_ignore_warning(
|
@@ -379,7 +376,7 @@ def _compilable_scatter_set(target, source, index):
|
|
379
376
|
target[:] = source.contiguous()[index].reshape_as(target)
|
380
377
|
|
381
378
|
|
382
|
-
|
379
|
+
#@decorator_knowngood
|
383
380
|
def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
|
384
381
|
"""
|
385
382
|
Computes the eigenbases of the preconditioner using one round of power iteration
|
@@ -398,7 +395,8 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
398
395
|
new_qs = []
|
399
396
|
|
400
397
|
for m, q in zip(GG, Q):
|
401
|
-
if
|
398
|
+
if m is None:
|
399
|
+
new_qs.append(None)
|
402
400
|
continue
|
403
401
|
|
404
402
|
m = promote(m.data)
|
@@ -420,19 +418,20 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
|
|
420
418
|
in_str = einsum_base[:exp_avg.dim()]
|
421
419
|
out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
|
422
420
|
|
423
|
-
from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if
|
421
|
+
from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
|
424
422
|
if not from_shampoo:
|
425
423
|
return
|
426
424
|
|
427
|
-
to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if
|
425
|
+
to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
|
428
426
|
out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
|
429
427
|
|
430
428
|
subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
|
431
|
-
exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
|
429
|
+
exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None])
|
432
430
|
copy_stochastic_(exp_avg, exp_avg_new)
|
433
431
|
|
434
432
|
for q, q_new in zip(Q, new_qs):
|
435
|
-
|
433
|
+
if q is not None:
|
434
|
+
copy_stochastic_(q, q_new)
|
436
435
|
|
437
436
|
|
438
437
|
def get_orthogonal_matrix(mat):
|
@@ -442,8 +441,8 @@ def get_orthogonal_matrix(mat):
|
|
442
441
|
|
443
442
|
final = []
|
444
443
|
for m in mat:
|
445
|
-
if
|
446
|
-
final.append(
|
444
|
+
if m is None:
|
445
|
+
final.append(None)
|
447
446
|
continue
|
448
447
|
|
449
448
|
m = promote(m.data)
|
@@ -476,7 +475,9 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
|
|
476
475
|
for x_, y_ in zip(x, y):
|
477
476
|
x32 = promote(x_)
|
478
477
|
y32 = promote(y_)
|
479
|
-
|
478
|
+
if x32.dtype != y32.dtype:
|
479
|
+
y32 = y32.to(x32.dtype)
|
480
|
+
copy_stochastic_(x_, x32 * (1 - a) + y32 * a)
|
480
481
|
|
481
482
|
|
482
483
|
def get_beta1(group):
|
@@ -575,7 +576,7 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
|
575
576
|
g0 = einsum_base[:grad.dim()]
|
576
577
|
g1 = g0.replace(b, b.upper())
|
577
578
|
outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
|
578
|
-
m
|
579
|
+
stochastic_lerp_(m, outer_product, 1 - beta)
|
579
580
|
|
580
581
|
|
581
582
|
def tree_apply(fn):
|
@@ -618,7 +619,8 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
|
|
618
619
|
state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
|
619
620
|
if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
|
620
621
|
for sh in grad.shape:
|
621
|
-
if sh > max_precond_dim:
|
622
|
+
if sh > max_precond_dim or sh == 1:
|
623
|
+
# via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
|
622
624
|
state['GG'].append(None)
|
623
625
|
else:
|
624
626
|
state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
|
@@ -638,10 +640,10 @@ def project(grad, Q, back: bool):
|
|
638
640
|
:return:
|
639
641
|
"""
|
640
642
|
param = einsum_base[:grad.dim()]
|
641
|
-
preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if
|
643
|
+
preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
|
642
644
|
if preconditioners:
|
643
645
|
out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
|
644
|
-
out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if
|
646
|
+
out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
|
645
647
|
grad = out.to(grad.dtype)
|
646
648
|
return grad
|
647
649
|
|
@@ -876,7 +878,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
|
|
876
878
|
ea32 = list(map(promote, state))
|
877
879
|
grad = list(map(promote, grad))
|
878
880
|
beta = promote(beta)
|
879
|
-
ea32
|
881
|
+
stochastic_lerp_(ea32, grad, 1 - beta)
|
880
882
|
copy_stochastic_list_(state, ea32)
|
881
883
|
return ea32
|
882
884
|
|
@@ -890,7 +892,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
|
|
890
892
|
g32 = list(map(promote, grad))
|
891
893
|
exp_avg32 = _lerp(exp_avg, g32, beta1)
|
892
894
|
denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
|
893
|
-
u32 =
|
895
|
+
u32 = torch._foreach_div(exp_avg32, denom)
|
894
896
|
copy_stochastic_list_(grad, u32)
|
895
897
|
|
896
898
|
|
@@ -973,14 +975,11 @@ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2,
|
|
973
975
|
_compilable_update_(y, u32, decay, lr, caution, g32)
|
974
976
|
|
975
977
|
beta1 = beta_debias(beta1, step)
|
976
|
-
denom =
|
977
|
-
|
978
|
-
exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
|
979
|
-
copy_stochastic_list_(exp_avg, exp_avg32)
|
978
|
+
denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
|
979
|
+
stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
|
980
980
|
|
981
981
|
beta2 = beta_debias(beta2, step + 1)
|
982
|
-
|
983
|
-
copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
|
982
|
+
stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
|
984
983
|
|
985
984
|
|
986
985
|
def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
|
@@ -990,27 +989,23 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e
|
|
990
989
|
|
991
990
|
|
992
991
|
@decorator_knowngood
|
993
|
-
def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
|
992
|
+
def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
|
994
993
|
g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
|
995
994
|
update = [e.clone() for e in exp_avg]
|
996
995
|
|
997
996
|
beta1 = beta_debias(beta1, step)
|
998
|
-
denom =
|
999
|
-
|
1000
|
-
exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
|
1001
|
-
copy_stochastic_list_(exp_avg, exp_avg32)
|
997
|
+
denom = [eps_sqrt(d, eps) for d in exp_avg_sq32]
|
998
|
+
stochastic_lerp_(exp_avg, torch._foreach_div(g32, denom), 1 - beta1)
|
1002
999
|
|
1003
|
-
|
1004
|
-
exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
|
1005
|
-
copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
|
1000
|
+
stochastic_lerp_(exp_avg_sq, torch._foreach_mul(g32, g32), 1 - beta2)
|
1006
1001
|
|
1007
1002
|
copy_stochastic_list_(grad, update)
|
1008
1003
|
|
1009
1004
|
|
1010
|
-
def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
|
1005
|
+
def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps: float = 1e-8):
|
1011
1006
|
exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
|
1012
|
-
beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
|
1013
|
-
_compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
|
1007
|
+
beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
|
1008
|
+
_compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps)
|
1014
1009
|
return grad
|
1015
1010
|
|
1016
1011
|
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
|
2
|
+
heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
|
3
|
+
heavyball/utils.py,sha256=iQxSQjw_sgJp4AvX71VdTJxJ_20Tdu7W2tdrYu5q2EI,55808
|
4
|
+
heavyball-1.6.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.6.1.dist-info/METADATA,sha256=yFMCDJPpD5jVOFtL4l_pM3jTw3_ZizeTSQ_ugVHIWKM,43441
|
6
|
+
heavyball-1.6.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.6.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.6.1.dist-info/RECORD,,
|
heavyball-1.6.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
|
2
|
-
heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
|
3
|
-
heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
|
4
|
-
heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
|
6
|
-
heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.6.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|