heavyball 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- heavyball/__init__.py +20 -1
- heavyball/chainable.py +50 -8
- heavyball/utils.py +589 -180
- {heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/METADATA +1 -1
- heavyball-1.7.2.dist-info/RECORD +8 -0
- heavyball-1.7.0.dist-info/RECORD +0 -8
- {heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/WHEEL +0 -0
- {heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/licenses/LICENSE +0 -0
- {heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 import functools
+import math
 from typing import Optional
 
 from . import chainable as C
@@ -564,6 +565,10 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
     hessian_approx = True
 
 
+class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
+    hvp_interval = 2
+
+
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -582,7 +587,7 @@ class ForeachPSGDLRA(C.BaseOpt):
         weight_decay=0.0,
         preconditioner_update_probability=None,
         momentum_into_precond_update=True,
-        rank: int =
+        rank: Optional[int] = None,
         warmup_steps: int = 0,
         foreach: bool = True,
         q_dtype="float32",
@@ -608,6 +613,14 @@ class ForeachPSGDLRA(C.BaseOpt):
         )
         params = defaults.pop("params")
 
+        if rank is None:
+            utils.warn_once(
+                f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
+            )
+            params = list(params)
+            defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
+            utils.warn_once(f"rank was set to {defaults['rank']}")
+
         delayed = C.default(delayed, self.delayed)
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
@@ -632,6 +645,10 @@ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
     hessian_approx = True
 
 
+class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
+    hvp_interval = 2
+
+
 PalmForEachSoap = PaLMForeachSOAP
 PaLMSOAP = PaLMForeachSOAP
 PaLMSFAdamW = PaLMForeachSFAdamW
@@ -696,4 +713,6 @@ __all__ = [
     "DelayedPSGD",
     "PSGDLRA",
     "NewtonPSGDLRA",
+    "NewtonHybrid2PSGDLRA",
+    "NewtonHybrid2PSGDKron",
 ]
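The two new optimizers above only override hvp_interval = 2, so the Hessian-vector product is refreshed on every second preconditioning step, and ForeachPSGDLRA now derives its default rank from the parameter count. A minimal sketch of that default rule (the log2 heuristic is taken from the hunk above; the parameter shapes are made up):

import math

import torch

# hypothetical parameter list, standing in for an actual model's parameters
params = [torch.nn.Parameter(torch.randn(128, 64)), torch.nn.Parameter(torch.randn(64))]

# rank=None now resolves to round(log2(total parameter count)), as in ForeachPSGDLRA above
param_count = sum(p.numel() for p in params)
rank = round(math.log2(param_count))
print(param_count, rank)  # 8256 -> 13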
heavyball/chainable.py
CHANGED
@@ -1,4 +1,5 @@
 import functools
+import math
 import random
 from typing import List, Literal, Optional, Union
 
@@ -43,7 +44,7 @@ class FunctionTransform:
         raise NotImplementedError
 
     def get_fn(self):
-        if hasattr(self.fn, "get_fn"):
+        if utils.hasattr_none(self.fn, "get_fn"):
            return self.fn.get_fn()
        return self.fn
 
@@ -426,7 +427,7 @@ def _store_std(state, group, update, grad, param):
     state["init_std"] = torch.std(grad, dim=0)
 
 
-@general_guard("init_std", init_fn=_store_std)
+@general_guard("init_std", init_fn=_store_std, skip_first=False)
 @no_state
 def mup_approx(group, updates, grads, params, init_std):
     _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
@@ -435,6 +436,40 @@ def mup_approx(group, updates, grads, params, init_std):
     return updates
 
 
+def _init_delta(state, group, update, grad, param, log_space: bool):
+    val = group["initial_d"]
+    state["delta"] = torch.full((), math.log(val) if log_space else val, dtype=param.dtype, device=param.device)
+
+
+def _init_full_delta(state, group, update, grad, param, log_space: bool):
+    val = group["initial_d"]
+    state["delta"] = torch.full_like(param, math.log(val) if log_space else val)
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
+@no_state
+def scale_by_d_adaptation(group, update, grad, param, state, delta):
+    utils.d_adaptation(grad, update, state, delta)
+    return update
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
+@no_state
+def scale_by_lr_adaptation(group, update, grad, param, state, delta):
+    utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
+    return update
+
+
+@zero_guard("state")
+@general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
+@no_state
+def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
+    utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
+    return update
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
@@ -484,18 +519,22 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
     if not group["is_preconditioning"]:
         return Q_mat
 
+    if utils.hasattr_none(param, "vector"):
+        vector, hessian_vector = param.vector, param.hessian_vector
+        del param.vector
+        del param.hessian_vector
+    else:
+        vector, hessian_vector = utils.dampen_grad(grad)
+
     utils.psgd_update_precond(
         Q_mat,
         exprs,
-
+        hessian_vector,
         group["precond_lr"],
         Q,
         group["store_triu_as_line"],
-
+        vector,
     )
-    if hasattr(param, "vector"):
-        del param.vector
-        del param.hessian_vector
 
     if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
         if group["store_triu_as_line"]:
@@ -566,9 +605,12 @@ def _update_lra(
     if not group["is_preconditioning"]:
         return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
 
-    if hasattr(params[0], "hessian_vector"):
+    if utils.hasattr_none(params[0], "hessian_vector"):
         vector = utils.flatten([p.vector for p in params])
         hessian_vector = utils.flatten([p.hessian_vector for p in params])
+        for p in params:
+            del p.vector
+            del p.hessian_vector
     else:
         vector, hessian_vector = utils.dampen_multiple(grads)
     return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
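The switch from hasattr to utils.hasattr_none matters because split_p_and_g_in_group (see utils.py below) now leaves p.vector and p.hessian_vector set to None instead of deleting them, so mere attribute presence no longer signals a pending HVP pair. A standalone sketch of the distinction (illustration only, not heavyball code):

class P:  # stand-in for a parameter tensor
    pass

def hasattr_none(obj, name):  # same definition as in heavyball/utils.py below
    return getattr(obj, name, None) is not None

p = P()
p.vector = None                    # attribute exists but carries no HVP data
print(hasattr(p, "vector"))        # True  -> the old check would wrongly take the HVP branch
print(hasattr_none(p, "vector"))   # False -> the new check falls back to dampen_grad / dampen_multiple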
heavyball/utils.py
CHANGED
@@ -1,11 +1,13 @@
+import contextlib
 import functools
 import gc
+import inspect
 import math
 import random
+import re
 import string
 import warnings
 from typing import Callable, List, Optional, Tuple, Union
-from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -15,13 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
 from torch.backends import cudnn, opt_einsum
 from torch.utils._pytree import tree_map
 
-config.cache_size_limit = 2 ** 16
+config.cache_size_limit = 2**16
 
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = "qr"  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
+_cudnn_double_backward_pattern = re.compile(
+    r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
+)
+_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
+_fd_error = (
+    "You can accelerate startup by globally enabling finite_differences first "  #
+    "(via opt.finite_differences=True or by subclassing it)\n"
+    "Original Error: "
+)
 
 
 def decorator(func):
@@ -39,7 +50,7 @@ def decorator(func):
     return _fn
 
 
-def decorator_knowngood(func: Callable):
+def decorator_knowngood(func: Callable, fullgraph: bool = True):
     compiled = None
 
     @functools.wraps(func)
@@ -48,7 +59,7 @@ def decorator_knowngood(func: Callable):
             return func(*args, **kwargs)
         nonlocal compiled
         if compiled is None:
-            compiled = torch.compile(fullgraph=
+            compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
         return compiled(*args, **kwargs)
 
     return _fn
@@ -58,8 +69,17 @@ einsum_base = string.ascii_lowercase
 
 
 @decorator_knowngood
-def _compilable_schedule_free_(
-
+def _compilable_schedule_free_(
+    p: List[Tensor],
+    z: List[Tensor],
+    ckp1: Tensor,
+    update: List[Tensor],
+    lr: Tensor,
+    beta1: Tensor,
+    decay: float,
+    grad: List[Tensor],
+    caution,
+):
     for op, oz, u_, g_ in zip(p, z, update, grad):
         u_ = u_.view_as(op)
         p_, z_, u_ = map(promote, (op, oz, u_))
@@ -74,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
         copy_stochastic_(oz, z_)
 
 
-def schedule_free_(
-
-
+def schedule_free_(
+    lr: float,
+    weight_lr_power: float,
+    weight_sum: float,
+    beta1: float,
+    parameters: List[Tensor],
+    z: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    caution: bool = False,
+    r: float = 0.0,
+    step: int = 0,
+    decay: float = 0.0,
+):
     weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
     weight_sum = weight_sum + weight
 
@@ -149,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
 
 
 def beta_debias(beta, step):
-    return 1 - (1 - beta) / (1 - beta ** step)
+    return 1 - (1 - beta) / (1 - beta**step)
 
 
 def eps_sqrt(item, eps):
@@ -157,8 +188,9 @@ def eps_sqrt(item, eps):
 
 
 @decorator_knowngood
-def _compilable_exp_avg_sq_(
-
+def _compilable_exp_avg_sq_(
+    state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
+):
     g32 = promote(grad)
     s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
 
@@ -219,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
     copy_stochastic_list_(gradients, g32)
 
 
-def adaptive_gradient_clipping_(
-
+def adaptive_gradient_clipping_(
+    parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
+):
     if clip_val <= 0:
         return gradients
     parameters, gradients = list_guard(parameters, gradients)
@@ -259,9 +292,11 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
 
     # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
     _ignore_warning(
-        "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
+        "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
+    )
     _ignore_warning(
-        "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
+        "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
+    )
 
 
 @decorator
@@ -408,7 +443,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
 
     assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
     in_str = einsum_base[: exp_avg.dim()]
-    out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]
+    out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
 
     from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
     if not from_shampoo:
@@ -418,8 +453,9 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
     out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
 
     subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
-    exp_avg_new = torch.einsum(
-        *[q for q in new_qs if q is not None]
+    exp_avg_new = torch.einsum(
+        subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
+    )
     copy_stochastic_(exp_avg, exp_avg_new)
 
     for q, q_new in zip(Q, new_qs):
@@ -546,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
     _compilable_stochastic_add_(x, y, alpha)
 
 
+@decorator_knowngood
+def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
+
+
+def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
+    x, y = list_guard(x, y)
+    alpha, divisor = scalar_guard(alpha, divisor, x[0])
+    _compilable_stochastic_add_divide_(x, y, alpha, divisor)
+
+
 @decorator_knowngood
 def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
     for x_, y_ in zip(x, y):
@@ -594,6 +644,20 @@ def promote(x):
     return x
 
 
+def promote_detach(x, should_promote):
+    if x is None:
+        return x
+    if should_promote:
+        x = promote(x)
+    return x.detach()
+
+
+def detach(x):
+    if isinstance(x, Tensor):
+        return x.detach()
+    return x
+
+
 def min_dtype(xs: List[Tensor]):
     dtypes = [x.dtype for x in xs]
     for d in (torch.float32, torch.bfloat16, torch.float16):
@@ -647,25 +711,36 @@ def project(grad, Q, back: bool):
         return grad
 
 
-
-
-
+@contextlib.contextmanager
+def patch_backward():
+    @contextlib.contextmanager
+    def _inner(module):
+        original = module.backward
 
-
-        closure: The closure function passed to the optimizer.
+        signature = inspect.signature(original)
 
-
-
-
+        def patched_backward(*args, **kwargs):
+            new_kwargs = signature.bind(*args)
+            new_kwargs.apply_defaults()
+            new_kwargs = new_kwargs.arguments
+            new_kwargs.update(kwargs)
+            new_kwargs["create_graph"] = True
+            return original(**new_kwargs)
+
+        module.backward = patched_backward
+        yield
+        module.backward = original
+
+    with _inner(torch.Tensor), _inner(torch.autograd):
+        yield
 
-    def patched_backward(self, *args, **kwargs):
-        kwargs["create_graph"] = True
-        return original_backward(self, *args, **kwargs)
 
-
+def hasattr_none(obj, name):
+    return getattr(obj, name, None) is not None
 
-
-
+
+class ExactHVPFailed(ValueError):
+    pass
 
 
 class StatefulOptimizer(torch.optim.Optimizer):
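patch_backward temporarily wraps torch.Tensor.backward and torch.autograd.backward so that a plain loss.backward() inside the closure keeps the graph alive for the exact-HVP path. The removed 1.7.0 helper simply forced kwargs["create_graph"] = True; 1.7.2 additionally binds positional arguments through inspect.signature. A minimal standalone sketch of the idea for torch.Tensor.backward only (hypothetical helper name, not the heavyball implementation):

import contextlib

import torch

@contextlib.contextmanager
def force_create_graph():
    original = torch.Tensor.backward

    def patched(self, *args, **kwargs):
        kwargs["create_graph"] = True  # the only change: keep the graph for a second backward
        return original(self, *args, **kwargs)

    torch.Tensor.backward = patched
    try:
        yield
    finally:
        torch.Tensor.backward = original

x = torch.randn(3, requires_grad=True)
with force_create_graph():
    x.pow(3).sum().backward()  # closure-style call, no create_graph argument anywhere
print(x.grad.requires_grad)    # True: the gradient itself is differentiable, so grad-of-grad works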
@@ -682,6 +757,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
     precond_schedule: Union[Callable, float, None] = None
     stochastic_schedule: bool = False
     finite_differences: bool = False
+    fallback_to_finite_differences: bool = True
+    _fallback_enabled: bool = False
+    hvp_interval: int = 1  # grad is faster initially, hvp later
 
     def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
         super().__init__(params, {**defaults, "foreach": foreach})
@@ -708,29 +786,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
         old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
         mars_correction(g_list, old_gs, mars_gamma, beta)
 
-    def split_p_and_g_in_group(
-
+    def split_p_and_g_in_group(
+        self,
+        group: dict,
+        skip_none: bool = True,
+        should_promote: bool = True,
+        beta1: float = -1.0,
+        raw: bool = False,
+    ):
         for p in group["params"]:
+            grad = getattr(p, "grad", None)
+            if grad is None and skip_none:
+                continue
+
+            p.grad = None
+
+            if raw:
+                yield p, grad
+                continue
+
             if p in self.mapping:
                 p_views = self.mapping[p]
             else:
                 self.mapping[p] = p_views = merge_group(group, p)
 
-
-
+            vector = getattr(p, "vector", None)
+            hessian_vector = getattr(p, "hessian_vector", None)
+            p.vector = None
+            p.hessian_vector = None
 
-
-
-
-
+            grad, vs, hvs = [
+                [None] * len(p_views) if x is None else merge_group(group, x)  #
+                for x in (grad, vector, hessian_vector)
+            ]
 
-            for pv, g in zip(p_views, grad):
-
-                    continue
-                if should_promote:
-                    g = promote(g)
+            for pv, g, v, hv in zip(p_views, grad, vs, hvs):
+                g = promote_detach(g, should_promote)
                 if beta1 >= 0 and group.get("mars", False):
                     self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
+                pv.vector = promote_detach(v, should_promote)
+                pv.hessian_vector = promote_detach(hv, should_promote)
                 yield pv, g
 
     def state_size(self) -> int:
@@ -794,6 +889,66 @@ class StatefulOptimizer(torch.optim.Optimizer):
             set_(self.state_(p)["param_ema"], p.data)
             set_(p.data, ema_clone)
 
+    def _finite_differences_hvp(self, closure):
+        with torch.enable_grad():
+            loss = closure()  # closure without retain_graph=True
+
+        grads = []
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                grads.append(g)
+                p.vector = torch.randn_like(p)
+                p.orig = p.data.clone()
+                # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
+                stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
+
+        with torch.enable_grad():
+            closure()
+
+        # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
+        # this costs more memory, but the imprecision seems too severe to use the other method
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                p.grad = grads.pop(0)
+                stochastic_add_(g, p.grad, -1)  # technically, we have to divide by the scale here
+                p.hessian_vector = g
+                p.data.copy_(p.orig)
+                del p.orig
+        return loss
+
+    def _double_backward_hvp(self, closure):
+        with torch.enable_grad(), patch_backward():
+            loss = closure()
+
+        params, grads = [], []
+        for group in self.param_groups:
+            for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
+                params.append(p)
+                grads.append(g)
+
+        if not params:
+            raise ValueError("No parameter has gradients")
+
+        vs = [torch.randn_like(p) for p in params]
+        with torch.enable_grad():
+            try:
+                hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
+            except RuntimeError as e:
+                raise ExactHVPFailed(str(e.args))
+
+        unused = []
+        for p, g, v, hv in zip(params, grads, vs, hvs):
+            p.hessian_vector = detach(hv)
+            p.grad = detach(g)
+            p.vector = detach(v)
+            if hv is None:
+                unused.append(list(p.shape))
+
+        if unused:
+            raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")
+
+        return loss
+
     def _handle_closure(self, closure):
         hessian_approx = self.hessian_approx and self._is_preconditioning
 
@@ -802,56 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 raise ValueError("Hessian approximation requires a closure.")
             return None
 
-
+        step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
+        if not hessian_approx or step % self.hvp_interval == 0:
             with torch.enable_grad():
                 loss = closure()
             return loss
 
-        if self.finite_differences:
-
-            loss = closure()  # closure without retain_graph=True
+        if self.finite_differences or self._fallback_enabled:
+            return self._finite_differences_hvp(closure)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for group in self.param_groups:
-            for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
-                p.grad = g
-        params, grads = zip(*[x for group in self.param_groups for x in
-                              self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
-        vs = [torch.randn_like(p) for p in params]
-        with torch.enable_grad():
-            hvs = torch.autograd.grad(grads, params, vs)
-
-        for p, g, v, hv in zip(params, grads, vs, hvs):
-            p.hessian_vector = hv
-            p.grad = g
-            p.vector = v
-
-        return loss
+        try:
+            return self._double_backward_hvp(closure)
+        except NotImplementedError as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
+                raise
+            warn_once(
+                "CUDNN doesn't support double-backward for some models (including RNNs). "  #
+                f"Falling back to finite_differences.\n{_fd_error}{e}"
+            )
+        except RuntimeError as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
+                raise
+            warn_once(
+                f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
+                f"the model.\n{_fd_error}{e}"
+            )
+        except ExactHVPFailed as e:
+            if not self.fallback_to_finite_differences:
+                raise
+            warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
+        self._fallback_enabled = True
+        return self._handle_closure(closure)
 
     def step(self, closure: Optional[Callable] = None):
         if self.precond_schedule is None:
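Both helpers estimate the same Hessian-vector product H @ v: _double_backward_hvp differentiates the gradient a second time, while _finite_differences_hvp re-evaluates the gradient at x + eps * v and keeps the (unscaled) difference, as the inline comment notes. A toy comparison on f(x) = sum(x**4), where H @ v = 12 * x**2 * v (illustration only, not heavyball code):

import torch

x = torch.randn(5, dtype=torch.double, requires_grad=True)
v = torch.randn(5, dtype=torch.double)

# exact HVP via double backward
(g,) = torch.autograd.grad(x.pow(4).sum(), x, create_graph=True)
(hv_exact,) = torch.autograd.grad(g, x, v)

# finite-differences HVP: (grad(x + eps * v) - grad(x)) / eps
eps = torch.finfo(x.dtype).eps ** 0.5
(g1,) = torch.autograd.grad(((x + eps * v) ** 4).sum(), x)
hv_fd = (g1 - g.detach()) / eps

print(torch.allclose(hv_exact, 12 * x.detach() ** 2 * v))  # True: matches the analytic Hessian
print((hv_fd - hv_exact).abs().max())                      # small finite-difference error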
@@ -867,7 +1007,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
                 self._step(group)
                 if self.use_ema:
                     self.ema_update()
-
+        for real, views in self.mapping.items():
+            for tensor in (real, *views):
+                for key in ("grad", "vector", "hessian_vector", "orig"):
+                    if hasattr(tensor, key):
+                        setattr(tensor, key, None)
         return loss
 
 
@@ -887,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
 
 
 @decorator_knowngood
-def _compilable_adam_(
-
+def _compilable_adam_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -899,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
     copy_stochastic_list_(grad, u32)
 
 
-def adam_(
-
+def adam_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -908,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
 
 
 @decorator_knowngood
-def _fused_compilable_adam_(
-
-
+def _fused_compilable_adam_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    decay: Tensor,
+    lr: Tensor,
+    eps: Tensor,
+    caution: bool,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -921,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
     _compilable_update_(y, u32, decay, lr, caution, g32)
 
 
-def fused_adam_(
-
-
+def fused_adam_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    lr: float,
+    eps: float,
+    decay: float,
+    caution: bool,
+):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(
-
+def _compilable_laprop_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -942,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
     copy_stochastic_list_(grad, gp32)
 
 
-def laprop_(
-
+def laprop_(
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
     _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -951,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
 
 
 @decorator_knowngood
-def _fused_compilable_laprop_(
-
-
+def _fused_compilable_laprop_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    lr: Tensor,
+    decay: Tensor,
+    caution: bool,
+    eps: Tensor,
+):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
@@ -964,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     _compilable_update_(y, u32, decay, lr, caution, gp32)
 
 
-def fused_laprop_(
-
-
+def fused_laprop_(
+    y: List[Tensor],
+    exp_avg: List[Tensor],
+    exp_avg_sq: List[Tensor],
+    update: List[Tensor],
+    grad: List[Tensor],
+    beta1: float,
+    beta2: float,
+    step: int,
+    lr: float,
+    decay: float,
+    caution: bool,
+    eps: float = 1e-8,
+):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
     _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -1040,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
 
 
 @decorator_knowngood
-def _compilable_update_(
-
+def _compilable_update_(
+    p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
+):
     for u_, g_, p_ in zip(u, g, p):  # lr is data-dependent -> can't compile a foreach
         u_ = promote(u_.view_as(p_))
         p32_ = promote(p_)
@@ -1051,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
         copy_stochastic_(p_, p32_)
 
 
-def update_param_(
-
+def update_param_(
+    param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
+):
     param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
@@ -1076,28 +1294,83 @@ def _max_idx(x: List[int]):
 
 
 @decorator_knowngood
-def
-
+def stable_exp(x: Tensor):
+    # fp16:
+    # exp(x) is stable in [-17, 11]
+    # `stable_exp` extends to [-17, 17]
+    # average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
+    # fp32:
+    # exp(x) is stable in [-103, 88]
+    # `stable_exp` extends to [-103, 103]
+    # average error (in [-87, 87]) reduced from 3.309-06 to 3.224-06
+    return torch.where(x > 0, 1 / (-x).exp(), x.exp())
+
+
+@decorator_knowngood
+def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
+    # 1 / (mean(x ** pow) ** (1 / pow / 2))
+    log_x = x.double().abs().clamp(min=eps).log()
+    log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
+    return stable_exp(-log_mean_x_pow / pow / 2)
 
 
 @decorator_knowngood
-def divided_root(x, y, pow0, pow1):
-
-
-
+def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
+    # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
+    log_x = x.double().abs().clamp(min=eps).log()
+    log_y = y.double().abs().clamp(min=eps).log()
+
+    x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
+    x_normed = x_normed / pow0 / 2
+
+    y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
+    y_normed = y_normed / pow1 / 2
+
+    return stable_exp(x_normed - y_normed)
 
 
-def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector):
+def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
+    automatic_scale = True
+    manual_hint = " Set it manually using `precond_init_scale=0.1`"
+
     if scale is not None:
+        automatic_scale = False
         warn_once(
-            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
-
+            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
+        )
+        if scale_scale is not None and scale_scale != 1:
             warn_once(
-            "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
+                "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
+            )
+    elif hessian_vector is None:
+        scale = mean_root(grad, 4) * scale_scale
+    else:
+        scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
+
+    if isinstance(scale, torch.Tensor):
+        scale = scale.item()  # slow, but necessary
+
+    if np.isfinite(scale):
+        if scale > scale_max or scale < 1 / scale_max:  # fallthrough to later checks
+            warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
+        else:
+            return scale
+
+    if not automatic_scale:
+        raise ValueError("The manually set precond_init_scale is not finite")
+
+    for x in (grad, hessian_vector, vector):
+        if x is None:
+            continue
+        if torch.allclose(x, torch.zeros_like(x)).item():
+            raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
+        if not torch.isfinite(x).all().item():
+            raise ValueError("Grad or HVP is not finite")
+
+    if np.isfinite(scale):
         return scale
-
-
-    return divided_root(vector, hessian_vector, 2, 4) * scale_scale
+
+    raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
 
 
 def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
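mean_root and divided_root back the new automatic precond_init_scale: the power means of the gradient and HVP are taken in log space so extreme magnitudes cannot overflow, and stable_exp maps back using exp(x) = 1 / exp(-x) for positive arguments. A small check of the identity mean_root(g, 4) == 1 / mean(g**4) ** (1/8) and of why log space matters (illustration only, not heavyball code):

import math

import torch

g = torch.full((4,), 1e160, dtype=torch.double)  # gradient magnitudes whose 4th power overflows float64

naive = 1 / g.pow(4).mean() ** (1 / 8)           # pow(4) -> inf, so the whole expression collapses to 0
print(naive.item())                              # 0.0

# log-space version, mirroring mean_root above
log_g = g.abs().clamp(min=1e-12).log()
log_mean = (log_g * 4).logsumexp(0) - math.log(g.numel())
print(torch.exp(-log_mean / 4 / 2).item())       # ~1e-80, the correct 1 / mean(g**4) ** (1/8)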
@@ -1108,8 +1381,9 @@ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None)
     return U, V, d
 
 
-def init_Q_exprs(
-
+def init_Q_exprs(
+    grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
+):
     """
     For a scalar or tensor `grad`, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
@@ -1147,8 +1421,10 @@ def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
-        raise ValueError(
-
+        raise ValueError(
+            f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
+            "[None, 'one_diag', 'all_diag', 'smart_one_diag']"
+        )
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
@@ -1213,8 +1489,16 @@ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
     return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
 
 
-def update_lra_precond_(
-
+def update_lra_precond_(
+    U: List[Tensor],
+    V: List[Tensor],
+    d: List[Tensor],
+    vector: Tensor,
+    hessian_vector: Tensor,
+    eps: float,
+    step: float,
+    delayed: bool,
+):
     """
     Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
     """
@@ -1293,7 +1577,7 @@ def lra_precond(U, V, d, g):
 
 
 @decorator_knowngood
-def dampen_grad(g: Tensor, damp: float = 2 ** -13):
+def dampen_grad(g: Tensor, damp: float = 2**-13):
     # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
     v = torch.randn_like(g)
     return v, g + damp * g.abs().mean() * v
@@ -1306,7 +1590,7 @@ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor,
     update = update.flatten()
     for p in params:
         size = p.numel()
-        copy_stochastic_(p, update[start: start + size].view_as(p))
+        copy_stochastic_(p, update[start : start + size].view_as(p))
         start += size
 
 
@@ -1316,7 +1600,7 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
     update = update.flatten()
     for p in params:
         size = p.numel()
-        copy_stochastic_(p, update[start: start + size].view_as(p))
+        copy_stochastic_(p, update[start : start + size].view_as(p))
         start += size
 
 
@@ -1326,7 +1610,7 @@ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
     update = update.flatten()
     for p in params:
         size = p.numel()
-        stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
+        stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
         start += size
 
 
@@ -1337,16 +1621,19 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
     update = update.flatten()
     for p in params:
         size = p.numel()
-        outputs.append(update[start: start + size].view_as(p))
+        outputs.append(update[start : start + size].view_as(p))
         start += size
     return outputs
 
 
+@decorator_knowngood
 def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
-
+    last_dim = x[0].shape[-remaining:] if remaining else []
+    return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
 
 
-def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
+@decorator_knowngood
+def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
     vs = []
     gs = []
     for g_ in g:
@@ -1356,30 +1643,58 @@ def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
     return flatten(vs), flatten(gs)
 
 
-def
-
-
-
-
-
-
-
-    for i, q in enumerate(
+def casted_einsum(expr: str, *args: Tensor) -> Tensor:
+    md = min_dtype(args)
+    return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
+
+
+@decorator_knowngood
+def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
+    triangular_qs = []
+    for i, q in enumerate(Qs):
+        q = promote(q)
         if q.dim() <= 1:
-
+            shape = [1] * conjB.ndim
+            shape[i] = -1
+            conjB /= q.view(shape)
         else:
-
-
-
-
+            triangular_qs.append((i, q))
+    return triangular_qs
+
+
+@decorator_knowngood
+def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
+    solved = solved.reshape(original_shape)
+    solved.transpose(last_dim, -1)
+    return solved.reshape(new_shape).contiguous()
+
+
+def psgd_calc_A_and_conjB(exprA, G, Q, conjB):  # conjB ("V", "vector") == randn during hvp/whitening
+    order = G.dim()
+    if order > 1:
+        conjB = conjB.view_as(G).permute(*range(1, order), 0)
+    conjB = conjB.to(promote(G.dtype))
+    A = casted_einsum(exprA, *Q, G)
+    solve = torch.compiler.disable(torch.linalg.solve_triangular)
+    original_shape = conjB.shape
+    prev_i = -1
+    for i, tri_q in _psgd_calc_scalars_(Q, conjB):
+        conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
+        prev_i = i
+        conjB = solve(tri_q, conjB, upper=True, left=False)
+    conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
     return A, conjB
 
 
-def psgd_lb(A, max_abs):
+@decorator_knowngood
+def _max_select(to_index: Tensor, to_argmax: Tensor):
+    idx = to_argmax.argmax()
+    return to_index.index_select(1, idx).flatten().contiguous()
+
+
+def psgd_lb(A: Tensor, max_abs: Tensor):
     A /= max_abs
-
-    i = torch.argmax(a0)
-    x = torch.index_select(A, 1, i).flatten().contiguous()
+    x = _max_select(A, torch.einsum("ij,ij->j", A, A))
     x = torch.einsum("i,ij->j", x, A)
     x /= x.norm()
     x = torch.einsum("j,kj->k", x, A)
@@ -1388,28 +1703,52 @@ def psgd_lb(A, max_abs):
     return x
 
 
+@decorator_knowngood
+def _subtract_from_line_(state: Tensor, term: Tensor):
+    stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
+
+
+@decorator_knowngood
+def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
+    out = term1.float().triu() * fac
+    out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
+    copy_stochastic_(term1, out)
+
+
+@decorator_knowngood
+def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
+    copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
+
+
+@decorator_knowngood
+def _compilable_add_sub_(x: Tensor, y: Tensor):
+    x = promote(x)
+    y = promote(y)
+    return x - y, x + y
+
+
 @decorator
 def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
     """Update Kronecker product preconditioner Q with pair (V, G)."""
     exprA, exprGs, _ = exprs
     A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+    precond_lr = scalar_guard(precond_lr, G)
 
     for q, exprG, o in zip(Q, exprGs, oq):
-        term1 =
-        term2 =
-        term1, term2 = term1
-        term1 *= precond_lr
+        term1 = torch.einsum(exprG, A, A)
+        term2 = torch.einsum(exprG, conjB, conjB)
+        term1, term2 = _compilable_add_sub_(term1, term2)
         norm = term2.norm(float("inf"))
         if q.dim() < 2:
-            term1
+            _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
         else:
-
-            term1
-
+            lower_bound = psgd_lb(term2, norm)
+            _prescale_term_(term1, precond_lr, lower_bound, norm)
+            torch.mm(term1, q.to(term1.dtype), out=term1)
         if store_triu_as_line:
-
-
-
+            _subtract_from_line_(q, term1)
+        else:
+            stochastic_add_(o, term1, -1)
 
 
 @decorator_knowngood
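psgd_lb (now routed through _max_select) produces a cheap lower bound on the spectral norm of term2, which _prescale_term_ divides by instead of computing an exact norm. A standalone sketch of the same estimate, checked against the true largest singular value (illustration only; the final two lines of the real psgd_lb fall outside this hunk, so the wrap-up here is an assumption):

import torch

def lb(A: torch.Tensor) -> torch.Tensor:
    # pick the column with the largest squared norm, then two matrix-vector
    # passes refine it into a lower bound on the largest singular value
    max_abs = A.abs().max()
    A = A / max_abs
    col = A[:, torch.einsum("ij,ij->j", A, A).argmax()]
    x = torch.einsum("i,ij->j", col, A)
    x = x / x.norm()
    return torch.einsum("j,kj->k", x, A).norm() * max_abs

A = torch.randn(64, 64, dtype=torch.double)
print(lb(A).item() <= torch.linalg.svdvals(A)[0].item())  # True: it never exceeds the spectral norm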
@@ -1619,8 +1958,9 @@ def warn_once(msg):
     _warned.add(msg)
 
 
-def psgd_should_update(
-
+def psgd_should_update(
+    group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
+):
     group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
     if not isinstance(prob, float):
         prob = prob(group[f"{name}_prob_step"])
@@ -1632,8 +1972,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(
-
+def precond_grad_cached_(
+    expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
+):
     if caution:
         ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
@@ -1753,12 +2094,79 @@ def merge_group(group, *tensors):
 
     out = []
     for t in tensors:
-        append_or_extend(
-
-
+        append_or_extend(
+            out,
+            dim_merger(
+                t,
+                group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
+                group.get("split", False),
+            ),
+        )
     return out
 
 
+@decorator_knowngood
+def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        next_d = d * (g * s).sum()
+        s = s + u * d
+        next_d = next_d / s.abs().sum()
+        next_d = torch.maximum(next_d, d)
+        copy_stochastic_(u_, u * d)
+        copy_stochastic_(d_, next_d)
+        copy_stochastic_(s_, s)
+
+
+def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    _compilable_d_adapt_(grads, update, state, delta)
+
+
+@decorator_knowngood
+def _compilable_lr_adapt_(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
+):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        lr_grad = d.sigmoid()
+        lr_grad = lr_grad * (1 - lr_grad)
+        lr_grad = lr_grad * (s * g).mean()
+        d = d - lr_grad * lr_lr
+        copy_stochastic_(d_, d)
+        copy_stochastic_(u_, u * d.sigmoid())
+        copy_stochastic_(s_, u)
+
+
+def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    lr_lr = scalar_guard(lr_lr, grads[0])
+    _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+
+
+@decorator_knowngood
+def _compilable_pointwise_lr_adapt_(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
+):
+    for g_, u_, s_, d_ in zip(grads, update, state, delta):
+        g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
+        lr_grad = d.sigmoid()
+        lr_grad = lr_grad * (1 - lr_grad)
+        lr_grad = lr_grad * s * g
+        d = d - lr_grad * lr_lr
+        copy_stochastic_(d_, d)
+        copy_stochastic_(u_, u * d.sigmoid())
+        copy_stochastic_(s_, u)
+
+
+def pointwise_lr_adaptation(
+    grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
+):
+    grads, update, state, delta = list_guard(grads, update, state, delta)
+    lr_lr = scalar_guard(lr_lr, grads[0])
+    _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
+
+
 def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
     optimizers = {}
 
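lr_adaptation is a hypergradient-style controller: delta is the logit of a learning-rate multiplier, and each step nudges it against the correlation between the previous update (state) and the current gradient, scaled by lr_lr; the pointwise variant applies the same rule per element instead of using the mean. A single-step sketch of the rule (illustration only, not heavyball code):

import torch

delta = torch.tensor(0.0)                # logit of the LR multiplier; sigmoid(0) = 0.5
prev_update = torch.tensor([0.3, -0.1])  # "state": the update applied on the previous step
grad = torch.tensor([0.2, -0.4])         # current gradient
lr_lr = 0.1

sig = delta.sigmoid()
hypergrad = sig * (1 - sig) * (prev_update * grad).mean()  # sigmoid'(delta) times the update/gradient correlation
delta = delta - hypergrad * lr_lr
print(delta.item(), delta.sigmoid().item())  # the multiplier moves opposite the sign of the correlation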
@@ -1781,8 +2189,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
 
     o = optimizer(parameters, *args, **kwargs)
     step_fn = o.step
-    o.step = functools.partial(
-        msg="You're trying to call `step` on a fused optimizer. This will not do anything."
+    o.step = functools.partial(
+        warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
+    )
 
     def _step(p: Tensor):
         seen_params.add(p)
heavyball-1.7.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=tt0QMvIbU6IRDexpjSWmWdNEVfYvsPT6-hAWfKrbDQc,20379
+heavyball/chainable.py,sha256=jkiTzaXFjEMJztN3TRGkBV7s0-deCakmR1QGIZHb54o,32635
+heavyball/utils.py,sha256=Y7YkYQhyUEZFUcTPQv6hrAL1gPE9oSydkuIEW5_LxbY,73545
+heavyball-1.7.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
+heavyball-1.7.2.dist-info/METADATA,sha256=iUY20QhT8d6hnb1udkOUnQyfRN_r8MM3Vhb0aq5eGNI,43718
+heavyball-1.7.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+heavyball-1.7.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.7.2.dist-info/RECORD,,
heavyball-1.7.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=64gbqGEWM0zxfXDCOAcB1VtIPd3sdzAOSNHBCzSg8uQ,19762
-heavyball/chainable.py,sha256=XCsBgBZtmd4swQSCtMmEpQtpsPbiJc18RAvaW9rlkIs,31174
-heavyball/utils.py,sha256=Uj3L-x5a56_G3G_VqqOrU7y098lxjkdjIwkKA7L5ETQ,62759
-heavyball-1.7.0.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
-heavyball-1.7.0.dist-info/METADATA,sha256=a8Aar_g95j_wZNL59vYc0BkIHwn49_RjDtflKON-HmQ,43718
-heavyball-1.7.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-heavyball-1.7.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.7.0.dist-info/RECORD,,
{heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/WHEEL
File without changes

{heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/licenses/LICENSE
File without changes

{heavyball-1.7.0.dist-info → heavyball-1.7.2.dist-info}/top_level.txt
File without changes