heavyball 1.7.0-py3-none-any.whl → 1.7.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import functools
2
+ import math
2
3
  from typing import Optional
3
4
 
4
5
  from . import chainable as C
@@ -564,6 +565,10 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
564
565
  hessian_approx = True
565
566
 
566
567
 
568
+ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
569
+ hvp_interval = 2
570
+
571
+
567
572
  class ForeachPSGDLRA(C.BaseOpt):
568
573
  """
569
574
  Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -582,7 +587,7 @@ class ForeachPSGDLRA(C.BaseOpt):
582
587
  weight_decay=0.0,
583
588
  preconditioner_update_probability=None,
584
589
  momentum_into_precond_update=True,
585
- rank: int = 4,
590
+ rank: Optional[int] = None,
586
591
  warmup_steps: int = 0,
587
592
  foreach: bool = True,
588
593
  q_dtype="float32",
@@ -608,6 +613,14 @@ class ForeachPSGDLRA(C.BaseOpt):
608
613
  )
609
614
  params = defaults.pop("params")
610
615
 
616
+ if rank is None:
617
+ utils.warn_once(
618
+ f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
619
+ )
620
+ params = list(params)
621
+ defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
622
+ utils.warn_once(f"rank was set to {defaults['rank']}")
623
+
611
624
  delayed = C.default(delayed, self.delayed)
612
625
  exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
613
626
  update_clipping = C.default(update_clipping, utils.trust_region_clip_)
@@ -632,6 +645,10 @@ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
632
645
  hessian_approx = True
633
646
 
634
647
 
648
+ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
649
+ hvp_interval = 2
650
+
651
+
635
652
  PalmForEachSoap = PaLMForeachSOAP
636
653
  PaLMSOAP = PaLMForeachSOAP
637
654
  PaLMSFAdamW = PaLMForeachSFAdamW
@@ -696,4 +713,6 @@ __all__ = [
696
713
  "DelayedPSGD",
697
714
  "PSGDLRA",
698
715
  "NewtonPSGDLRA",
716
+ "NewtonHybrid2PSGDLRA",
717
+ "NewtonHybrid2PSGDKron",
699
718
  ]
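
Note on the __init__.py changes above: the user-facing additions are the NewtonHybrid2PSGDKron and NewtonHybrid2PSGDLRA optimizers (which only set hvp_interval = 2, so an exact Hessian-vector product is attempted on every other preconditioning step rather than every step) and the new rank=None default for ForeachPSGDLRA, which derives the rank from log2 of the parameter count. A minimal usage sketch, assuming the constructor mirrors the other PSGD optimizers (params plus keyword arguments such as lr) and remembering that Hessian approximation requires a closure:

import torch
import torch.nn.functional as F

import heavyball

# Sketch only: model, data, and hyperparameters are illustrative.
model = torch.nn.Linear(16, 1)
opt = heavyball.NewtonHybrid2PSGDKron(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 16), torch.randn(32, 1)

def closure():
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()  # heavyball patches backward() to pass create_graph=True on HVP steps
    return loss

for _ in range(10):
    opt.step(closure)  # with hvp_interval=2, exact HVPs run only on every other preconditioning step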
heavyball/chainable.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import functools
2
+ import math
2
3
  import random
3
4
  from typing import List, Literal, Optional, Union
4
5
 
@@ -43,7 +44,7 @@ class FunctionTransform:
43
44
  raise NotImplementedError
44
45
 
45
46
  def get_fn(self):
46
- if hasattr(self.fn, "get_fn"):
47
+ if utils.hasattr_none(self.fn, "get_fn"):
47
48
  return self.fn.get_fn()
48
49
  return self.fn
49
50
 
@@ -426,7 +427,7 @@ def _store_std(state, group, update, grad, param):
426
427
  state["init_std"] = torch.std(grad, dim=0)
427
428
 
428
429
 
429
- @general_guard("init_std", init_fn=_store_std)
430
+ @general_guard("init_std", init_fn=_store_std, skip_first=False)
430
431
  @no_state
431
432
  def mup_approx(group, updates, grads, params, init_std):
432
433
  _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
@@ -435,6 +436,40 @@ def mup_approx(group, updates, grads, params, init_std):
435
436
  return updates
436
437
 
437
438
 
439
+ def _init_delta(state, group, update, grad, param, log_space: bool):
440
+ val = group["initial_d"]
441
+ state["delta"] = torch.full((), math.log(val) if log_space else val, dtype=param.dtype, device=param.device)
442
+
443
+
444
+ def _init_full_delta(state, group, update, grad, param, log_space: bool):
445
+ val = group["initial_d"]
446
+ state["delta"] = torch.full_like(param, math.log(val) if log_space else val)
447
+
448
+
449
+ @zero_guard("state")
450
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
451
+ @no_state
452
+ def scale_by_d_adaptation(group, update, grad, param, state, delta):
453
+ utils.d_adaptation(grad, update, state, delta)
454
+ return update
455
+
456
+
457
+ @zero_guard("state")
458
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
459
+ @no_state
460
+ def scale_by_lr_adaptation(group, update, grad, param, state, delta):
461
+ utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
462
+ return update
463
+
464
+
465
+ @zero_guard("state")
466
+ @general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
467
+ @no_state
468
+ def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
469
+ utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
470
+ return update
471
+
472
+
438
473
  @zero_guard("momentum")
439
474
  @no_state
440
475
  def heavyball_momentum(group, updates, grads, params, momentum):
@@ -484,18 +519,22 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
484
519
  if not group["is_preconditioning"]:
485
520
  return Q_mat
486
521
 
522
+ if utils.hasattr_none(param, "vector"):
523
+ vector, hessian_vector = param.vector, param.hessian_vector
524
+ del param.vector
525
+ del param.hessian_vector
526
+ else:
527
+ vector, hessian_vector = utils.dampen_grad(grad)
528
+
487
529
  utils.psgd_update_precond(
488
530
  Q_mat,
489
531
  exprs,
490
- getattr(param, "hessian_vector", grad),
532
+ hessian_vector,
491
533
  group["precond_lr"],
492
534
  Q,
493
535
  group["store_triu_as_line"],
494
- getattr(param, "vector", None),
536
+ vector,
495
537
  )
496
- if hasattr(param, "vector"):
497
- del param.vector
498
- del param.hessian_vector
499
538
 
500
539
  if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
501
540
  if group["store_triu_as_line"]:
@@ -566,9 +605,12 @@ def _update_lra(
566
605
  if not group["is_preconditioning"]:
567
606
  return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
568
607
 
569
- if hasattr(params[0], "hessian_vector") and params[0].hessian_vector is not None:
608
+ if utils.hasattr_none(params[0], "hessian_vector"):
570
609
  vector = utils.flatten([p.vector for p in params])
571
610
  hessian_vector = utils.flatten([p.hessian_vector for p in params])
611
+ for p in params:
612
+ del p.vector
613
+ del p.hessian_vector
572
614
  else:
573
615
  vector, hessian_vector = utils.dampen_multiple(grads)
574
616
  return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
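
Note on the chainable.py changes above: three new transforms are added, scale_by_d_adaptation, scale_by_lr_adaptation, and scale_by_pointwise_lr_adaptation, each pairing a zero_guard'ed state buffer with a delta buffer initialised from group["initial_d"] (optionally in log space) and delegating to the matching helpers added in utils.py further below. As a standalone sketch (single tensor, fp32, no stochastic rounding, with a hypothetical helper name), the per-step rule of lr_adaptation keeps a learning-rate multiplier sigmoid(delta) and moves delta along a hypergradient estimated from the previous update and the current gradient:

import torch

def lr_adapt_step(grad, update, state, delta, lr_lr=0.1):
    """Single-tensor illustration of utils._compilable_lr_adapt_ (sketch, not the packaged code)."""
    lr_grad = delta.sigmoid()                  # current multiplier
    lr_grad = lr_grad * (1 - lr_grad)          # derivative of sigmoid w.r.t. delta
    lr_grad = lr_grad * (state * grad).mean()  # correlation of the previous update with the new gradient
    delta = delta - lr_grad * lr_lr
    scaled_update = update * delta.sigmoid()   # what the optimizer actually applies
    new_state = update                         # raw update, remembered for the next step
    return scaled_update, new_state, delta

g, u = torch.randn(10), torch.randn(10)
s = torch.zeros(10)
d = torch.full((), 0.0)  # the log-space variants start delta at log(initial_d)
out, s, d = lr_adapt_step(g, u, s, d)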
heavyball/utils.py CHANGED
@@ -1,11 +1,13 @@
1
+ import contextlib
1
2
  import functools
2
3
  import gc
4
+ import inspect
3
5
  import math
4
6
  import random
7
+ import re
5
8
  import string
6
9
  import warnings
7
10
  from typing import Callable, List, Optional, Tuple, Union
8
- from unittest.mock import patch
9
11
 
10
12
  import numpy as np
11
13
  import torch
@@ -15,13 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
15
17
  from torch.backends import cudnn, opt_einsum
16
18
  from torch.utils._pytree import tree_map
17
19
 
18
- config.cache_size_limit = 2 ** 16
20
+ config.cache_size_limit = 2**16
19
21
 
20
22
  compile_mode = "max-autotune-no-cudagraphs"
21
23
  dynamic = False
22
24
  compile_mode_recommended_to_none = None
23
25
  zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
24
26
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
+ _cudnn_double_backward_pattern = re.compile(
28
+ r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
29
+ )
30
+ _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
31
+ _fd_error = (
32
+ "You can accelerate startup by globally enabling finite_differences first " #
33
+ "(via opt.finite_differences=True or by subclassing it)\n"
34
+ "Original Error: "
35
+ )
25
36
 
26
37
 
27
38
  def decorator(func):
@@ -39,7 +50,7 @@ def decorator(func):
39
50
  return _fn
40
51
 
41
52
 
42
- def decorator_knowngood(func: Callable):
53
+ def decorator_knowngood(func: Callable, fullgraph: bool = True):
43
54
  compiled = None
44
55
 
45
56
  @functools.wraps(func)
@@ -48,7 +59,7 @@ def decorator_knowngood(func: Callable):
48
59
  return func(*args, **kwargs)
49
60
  nonlocal compiled
50
61
  if compiled is None:
51
- compiled = torch.compile(fullgraph=True, dynamic=dynamic, mode=compile_mode)(func)
62
+ compiled = torch.compile(fullgraph=fullgraph, dynamic=dynamic, mode=compile_mode)(func)
52
63
  return compiled(*args, **kwargs)
53
64
 
54
65
  return _fn
@@ -58,8 +69,17 @@ einsum_base = string.ascii_lowercase
58
69
 
59
70
 
60
71
  @decorator_knowngood
61
- def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
62
- beta1: Tensor, decay: float, grad: List[Tensor], caution, ):
72
+ def _compilable_schedule_free_(
73
+ p: List[Tensor],
74
+ z: List[Tensor],
75
+ ckp1: Tensor,
76
+ update: List[Tensor],
77
+ lr: Tensor,
78
+ beta1: Tensor,
79
+ decay: float,
80
+ grad: List[Tensor],
81
+ caution,
82
+ ):
63
83
  for op, oz, u_, g_ in zip(p, z, update, grad):
64
84
  u_ = u_.view_as(op)
65
85
  p_, z_, u_ = map(promote, (op, oz, u_))
@@ -74,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
74
94
  copy_stochastic_(oz, z_)
75
95
 
76
96
 
77
- def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
78
- z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0, step: int = 0,
79
- decay: float = 0.0, ):
97
+ def schedule_free_(
98
+ lr: float,
99
+ weight_lr_power: float,
100
+ weight_sum: float,
101
+ beta1: float,
102
+ parameters: List[Tensor],
103
+ z: List[Tensor],
104
+ update: List[Tensor],
105
+ grad: List[Tensor],
106
+ caution: bool = False,
107
+ r: float = 0.0,
108
+ step: int = 0,
109
+ decay: float = 0.0,
110
+ ):
80
111
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
81
112
  weight_sum = weight_sum + weight
82
113
 
@@ -149,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
149
180
 
150
181
 
151
182
  def beta_debias(beta, step):
152
- return 1 - (1 - beta) / (1 - beta ** step)
183
+ return 1 - (1 - beta) / (1 - beta**step)
153
184
 
154
185
 
155
186
  def eps_sqrt(item, eps):
@@ -157,8 +188,9 @@ def eps_sqrt(item, eps):
157
188
 
158
189
 
159
190
  @decorator_knowngood
160
- def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
161
- out: List[Optional[Tensor]]):
191
+ def _compilable_exp_avg_sq_(
192
+ state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
193
+ ):
162
194
  g32 = promote(grad)
163
195
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
164
196
 
@@ -219,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
219
251
  copy_stochastic_list_(gradients, g32)
220
252
 
221
253
 
222
- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
223
- minimum: float = 1e-3, eps: float = 1e-8):
254
+ def adaptive_gradient_clipping_(
255
+ parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
256
+ ):
224
257
  if clip_val <= 0:
225
258
  return gradients
226
259
  parameters, gradients = list_guard(parameters, gradients)
@@ -259,9 +292,11 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
259
292
 
260
293
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
261
294
  _ignore_warning(
262
- "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak")
295
+ "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
296
+ )
263
297
  _ignore_warning(
264
- "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak")
298
+ "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
+ )
265
300
 
266
301
 
267
302
  @decorator
@@ -408,7 +443,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
408
443
 
409
444
  assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
410
445
  in_str = einsum_base[: exp_avg.dim()]
411
- out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]
446
+ out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
412
447
 
413
448
  from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
414
449
  if not from_shampoo:
@@ -418,8 +453,9 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
418
453
  out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
419
454
 
420
455
  subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
421
- exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
422
- *[q for q in new_qs if q is not None])
456
+ exp_avg_new = torch.einsum(
457
+ subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
+ )
423
459
  copy_stochastic_(exp_avg, exp_avg_new)
424
460
 
425
461
  for q, q_new in zip(Q, new_qs):
@@ -546,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
546
582
  _compilable_stochastic_add_(x, y, alpha)
547
583
 
548
584
 
585
+ @decorator_knowngood
586
+ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
587
+ for x_, y_ in zip(x, y):
588
+ x32 = promote(x_)
589
+ y32 = promote(y_)
590
+ copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
591
+
592
+
593
+ def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
594
+ x, y = list_guard(x, y)
595
+ alpha, divisor = scalar_guard(alpha, divisor, x[0])
596
+ _compilable_stochastic_add_divide_(x, y, alpha, divisor)
597
+
598
+
549
599
  @decorator_knowngood
550
600
  def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
551
601
  for x_, y_ in zip(x, y):
@@ -594,6 +644,20 @@ def promote(x):
594
644
  return x
595
645
 
596
646
 
647
+ def promote_detach(x, should_promote):
648
+ if x is None:
649
+ return x
650
+ if should_promote:
651
+ x = promote(x)
652
+ return x.detach()
653
+
654
+
655
+ def detach(x):
656
+ if isinstance(x, Tensor):
657
+ return x.detach()
658
+ return x
659
+
660
+
597
661
  def min_dtype(xs: List[Tensor]):
598
662
  dtypes = [x.dtype for x in xs]
599
663
  for d in (torch.float32, torch.bfloat16, torch.float16):
@@ -647,25 +711,36 @@ def project(grad, Q, back: bool):
647
711
  return grad
648
712
 
649
713
 
650
- def modify_closure(closure):
651
- """
652
- Modifies the closure function to use create_graph=True in backward().
714
+ @contextlib.contextmanager
715
+ def patch_backward():
716
+ @contextlib.contextmanager
717
+ def _inner(module):
718
+ original = module.backward
653
719
 
654
- Args:
655
- closure: The closure function passed to the optimizer.
720
+ signature = inspect.signature(original)
656
721
 
657
- Returns:
658
- The return value of the modified closure.
659
- """
722
+ def patched_backward(*args, **kwargs):
723
+ new_kwargs = signature.bind(*args)
724
+ new_kwargs.apply_defaults()
725
+ new_kwargs = new_kwargs.arguments
726
+ new_kwargs.update(kwargs)
727
+ new_kwargs["create_graph"] = True
728
+ return original(**new_kwargs)
729
+
730
+ module.backward = patched_backward
731
+ yield
732
+ module.backward = original
733
+
734
+ with _inner(torch.Tensor), _inner(torch.autograd):
735
+ yield
660
736
 
661
- def patched_backward(self, *args, **kwargs):
662
- kwargs["create_graph"] = True
663
- return original_backward(self, *args, **kwargs)
664
737
 
665
- original_backward = torch.Tensor.backward
738
+ def hasattr_none(obj, name):
739
+ return getattr(obj, name, None) is not None
666
740
 
667
- with patch.object(torch.Tensor, "backward", patched_backward):
668
- return closure()
741
+
742
+ class ExactHVPFailed(ValueError):
743
+ pass
669
744
 
670
745
 
671
746
  class StatefulOptimizer(torch.optim.Optimizer):
@@ -682,6 +757,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
682
757
  precond_schedule: Union[Callable, float, None] = None
683
758
  stochastic_schedule: bool = False
684
759
  finite_differences: bool = False
760
+ fallback_to_finite_differences: bool = True
761
+ _fallback_enabled: bool = False
762
+ hvp_interval: int = 1 # grad is faster initially, hvp later
685
763
 
686
764
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
687
765
  super().__init__(params, {**defaults, "foreach": foreach})
@@ -708,29 +786,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
708
786
  old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
709
787
  mars_correction(g_list, old_gs, mars_gamma, beta)
710
788
 
711
- def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
712
- beta1: float = -1.0):
789
+ def split_p_and_g_in_group(
790
+ self,
791
+ group: dict,
792
+ skip_none: bool = True,
793
+ should_promote: bool = True,
794
+ beta1: float = -1.0,
795
+ raw: bool = False,
796
+ ):
713
797
  for p in group["params"]:
798
+ grad = getattr(p, "grad", None)
799
+ if grad is None and skip_none:
800
+ continue
801
+
802
+ p.grad = None
803
+
804
+ if raw:
805
+ yield p, grad
806
+ continue
807
+
714
808
  if p in self.mapping:
715
809
  p_views = self.mapping[p]
716
810
  else:
717
811
  self.mapping[p] = p_views = merge_group(group, p)
718
812
 
719
- grad = getattr(p, "grad", None)
720
- p.grad = None
813
+ vector = getattr(p, "vector", None)
814
+ hessian_vector = getattr(p, "hessian_vector", None)
815
+ p.vector = None
816
+ p.hessian_vector = None
721
817
 
722
- if grad is None:
723
- grad = [getattr(pv, "grad", None) for pv in p_views]
724
- else:
725
- grad = merge_group(group, grad)
818
+ grad, vs, hvs = [
819
+ [None] * len(p_views) if x is None else merge_group(group, x) #
820
+ for x in (grad, vector, hessian_vector)
821
+ ]
726
822
 
727
- for pv, g in zip(p_views, grad):
728
- if skip_none and g is None:
729
- continue
730
- if should_promote:
731
- g = promote(g)
823
+ for pv, g, v, hv in zip(p_views, grad, vs, hvs):
824
+ g = promote_detach(g, should_promote)
732
825
  if beta1 >= 0 and group.get("mars", False):
733
826
  self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
827
+ pv.vector = promote_detach(v, should_promote)
828
+ pv.hessian_vector = promote_detach(hv, should_promote)
734
829
  yield pv, g
735
830
 
736
831
  def state_size(self) -> int:
@@ -794,6 +889,66 @@ class StatefulOptimizer(torch.optim.Optimizer):
794
889
  set_(self.state_(p)["param_ema"], p.data)
795
890
  set_(p.data, ema_clone)
796
891
 
892
+ def _finite_differences_hvp(self, closure):
893
+ with torch.enable_grad():
894
+ loss = closure() # closure without retain_graph=True
895
+
896
+ grads = []
897
+ for group in self.param_groups:
898
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
899
+ grads.append(g)
900
+ p.vector = torch.randn_like(p)
901
+ p.orig = p.data.clone()
902
+ # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
903
+ stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
904
+
905
+ with torch.enable_grad():
906
+ closure()
907
+
908
+ # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
909
+ # this costs more memory, but the imprecision seems too severe to use the other method
910
+ for group in self.param_groups:
911
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
912
+ p.grad = grads.pop(0)
913
+ stochastic_add_(g, p.grad, -1) # technically, we have to divide by the scale here
914
+ p.hessian_vector = g
915
+ p.data.copy_(p.orig)
916
+ del p.orig
917
+ return loss
918
+
919
+ def _double_backward_hvp(self, closure):
920
+ with torch.enable_grad(), patch_backward():
921
+ loss = closure()
922
+
923
+ params, grads = [], []
924
+ for group in self.param_groups:
925
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
926
+ params.append(p)
927
+ grads.append(g)
928
+
929
+ if not params:
930
+ raise ValueError("No parameter has gradients")
931
+
932
+ vs = [torch.randn_like(p) for p in params]
933
+ with torch.enable_grad():
934
+ try:
935
+ hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
936
+ except RuntimeError as e:
937
+ raise ExactHVPFailed(str(e.args))
938
+
939
+ unused = []
940
+ for p, g, v, hv in zip(params, grads, vs, hvs):
941
+ p.hessian_vector = detach(hv)
942
+ p.grad = detach(g)
943
+ p.vector = detach(v)
944
+ if hv is None:
945
+ unused.append(list(p.shape))
946
+
947
+ if unused:
948
+ raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")
949
+
950
+ return loss
951
+
797
952
  def _handle_closure(self, closure):
798
953
  hessian_approx = self.hessian_approx and self._is_preconditioning
799
954
 
@@ -802,56 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
802
957
  raise ValueError("Hessian approximation requires a closure.")
803
958
  return None
804
959
 
805
- if not hessian_approx:
960
+ step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
+ if not hessian_approx or step % self.hvp_interval == 0:
806
962
  with torch.enable_grad():
807
963
  loss = closure()
808
964
  return loss
809
965
 
810
- if self.finite_differences:
811
- with torch.enable_grad():
812
- loss = closure() # closure without retain_graph=True
966
+ if self.finite_differences or self._fallback_enabled:
967
+ return self._finite_differences_hvp(closure)
813
968
 
814
- grads = []
815
- for group in self.param_groups:
816
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
817
- grads.append(g)
818
- p.vector = torch.randn_like(p)
819
- p.orig = p.data.clone()
820
- # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
821
- stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
822
- else:
823
- with torch.enable_grad():
824
- loss = modify_closure(closure)
825
-
826
- if self.finite_differences:
827
- with torch.enable_grad():
828
- closure()
829
-
830
- # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
831
- # this costs more memory, but the imprecision seems too severe to use the other method
832
- for group in self.param_groups:
833
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
834
- p.grad = grads.pop(0)
835
- stochastic_add_(g, p.grad, -1)
836
- p.hessian_vector = g
837
- p.data.copy_(p.orig)
838
- del p.orig
839
- else:
840
- for group in self.param_groups:
841
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
842
- p.grad = g
843
- params, grads = zip(*[x for group in self.param_groups for x in
844
- self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
845
- vs = [torch.randn_like(p) for p in params]
846
- with torch.enable_grad():
847
- hvs = torch.autograd.grad(grads, params, vs)
848
-
849
- for p, g, v, hv in zip(params, grads, vs, hvs):
850
- p.hessian_vector = hv
851
- p.grad = g
852
- p.vector = v
853
-
854
- return loss
969
+ try:
970
+ return self._double_backward_hvp(closure)
971
+ except NotImplementedError as e:
972
+ if not self.fallback_to_finite_differences:
973
+ raise
974
+ if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
975
+ raise
976
+ warn_once(
977
+ "CUDNN doesn't support double-backward for some models (including RNNs). " #
978
+ f"Falling back to finite_differences.\n{_fd_error}{e}"
979
+ )
980
+ except RuntimeError as e:
981
+ if not self.fallback_to_finite_differences:
982
+ raise
983
+ if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
984
+ raise
985
+ warn_once(
986
+ f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
987
+ f"the model.\n{_fd_error}{e}"
988
+ )
989
+ except ExactHVPFailed as e:
990
+ if not self.fallback_to_finite_differences:
991
+ raise
992
+ warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
993
+ self._fallback_enabled = True
994
+ return self._handle_closure(closure)
855
995
 
856
996
  def step(self, closure: Optional[Callable] = None):
857
997
  if self.precond_schedule is None:
@@ -867,7 +1007,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
867
1007
  self._step(group)
868
1008
  if self.use_ema:
869
1009
  self.ema_update()
870
-
1010
+ for real, views in self.mapping.items():
1011
+ for tensor in (real, *views):
1012
+ for key in ("grad", "vector", "hessian_vector", "orig"):
1013
+ if hasattr(tensor, key):
1014
+ setattr(tensor, key, None)
871
1015
  return loss
872
1016
 
873
1017
 
@@ -887,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
887
1031
 
888
1032
 
889
1033
  @decorator_knowngood
890
- def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
891
- step: Tensor, eps: Tensor, ):
1034
+ def _compilable_adam_(
1035
+ exp_avg: List[Tensor],
1036
+ exp_avg_sq: List[Tensor],
1037
+ grad: List[Tensor],
1038
+ beta1: Tensor,
1039
+ beta2: Tensor,
1040
+ step: Tensor,
1041
+ eps: Tensor,
1042
+ ):
892
1043
  beta1 = beta_debias(beta1, step)
893
1044
  beta2 = beta_debias(beta2, step)
894
1045
 
@@ -899,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
899
1050
  copy_stochastic_list_(grad, u32)
900
1051
 
901
1052
 
902
- def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
903
- eps: float = 1e-8, ):
1053
+ def adam_(
1054
+ exp_avg: List[Tensor],
1055
+ exp_avg_sq: List[Tensor],
1056
+ grad: List[Tensor],
1057
+ beta1: float,
1058
+ beta2: float,
1059
+ step: int,
1060
+ eps: float = 1e-8,
1061
+ ):
904
1062
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
905
1063
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
906
1064
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -908,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
908
1066
 
909
1067
 
910
1068
  @decorator_knowngood
911
- def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
912
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
913
- caution: bool, ):
1069
+ def _fused_compilable_adam_(
1070
+ y: List[Tensor],
1071
+ exp_avg: List[Tensor],
1072
+ exp_avg_sq: List[Tensor],
1073
+ update: List[Tensor],
1074
+ grad: List[Tensor],
1075
+ beta1: Tensor,
1076
+ beta2: Tensor,
1077
+ step: Tensor,
1078
+ decay: Tensor,
1079
+ lr: Tensor,
1080
+ eps: Tensor,
1081
+ caution: bool,
1082
+ ):
914
1083
  beta1 = beta_debias(beta1, step)
915
1084
  beta2 = beta_debias(beta2, step)
916
1085
 
@@ -921,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
921
1090
  _compilable_update_(y, u32, decay, lr, caution, g32)
922
1091
 
923
1092
 
924
- def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
925
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
926
- caution: bool, ):
1093
+ def fused_adam_(
1094
+ y: List[Tensor],
1095
+ exp_avg: List[Tensor],
1096
+ exp_avg_sq: List[Tensor],
1097
+ update: List[Tensor],
1098
+ grad: List[Tensor],
1099
+ beta1: float,
1100
+ beta2: float,
1101
+ step: int,
1102
+ lr: float,
1103
+ eps: float,
1104
+ decay: float,
1105
+ caution: bool,
1106
+ ):
927
1107
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
928
1108
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
929
1109
  _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
930
1110
 
931
1111
 
932
1112
  @decorator_knowngood
933
- def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
934
- beta2: Tensor, step: Tensor, eps: Tensor, ):
1113
+ def _compilable_laprop_(
1114
+ exp_avg: List[Tensor],
1115
+ exp_avg_sq: List[Tensor],
1116
+ grad: List[Tensor],
1117
+ beta1: Tensor,
1118
+ beta2: Tensor,
1119
+ step: Tensor,
1120
+ eps: Tensor,
1121
+ ):
935
1122
  beta1 = beta_debias(beta1, step)
936
1123
  beta2 = beta_debias(beta2, step)
937
1124
 
@@ -942,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
942
1129
  copy_stochastic_list_(grad, gp32)
943
1130
 
944
1131
 
945
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
946
- eps: float = 1e-8, ):
1132
+ def laprop_(
1133
+ exp_avg: List[Tensor],
1134
+ exp_avg_sq: List[Tensor],
1135
+ grad: List[Tensor],
1136
+ beta1: float,
1137
+ beta2: float,
1138
+ step: int,
1139
+ eps: float = 1e-8,
1140
+ ):
947
1141
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
948
1142
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
949
1143
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -951,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
951
1145
 
952
1146
 
953
1147
  @decorator_knowngood
954
- def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
955
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor, caution: bool,
956
- eps: Tensor, ):
1148
+ def _fused_compilable_laprop_(
1149
+ y: List[Tensor],
1150
+ exp_avg: List[Tensor],
1151
+ exp_avg_sq: List[Tensor],
1152
+ update: List[Tensor],
1153
+ grad: List[Tensor],
1154
+ beta1: Tensor,
1155
+ beta2: Tensor,
1156
+ step: Tensor,
1157
+ lr: Tensor,
1158
+ decay: Tensor,
1159
+ caution: bool,
1160
+ eps: Tensor,
1161
+ ):
957
1162
  beta1 = beta_debias(beta1, step)
958
1163
  beta2 = beta_debias(beta2, step)
959
1164
 
@@ -964,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
964
1169
  _compilable_update_(y, u32, decay, lr, caution, gp32)
965
1170
 
966
1171
 
967
- def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
968
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
969
- eps: float = 1e-8, ):
1172
+ def fused_laprop_(
1173
+ y: List[Tensor],
1174
+ exp_avg: List[Tensor],
1175
+ exp_avg_sq: List[Tensor],
1176
+ update: List[Tensor],
1177
+ grad: List[Tensor],
1178
+ beta1: float,
1179
+ beta2: float,
1180
+ step: int,
1181
+ lr: float,
1182
+ decay: float,
1183
+ caution: bool,
1184
+ eps: float = 1e-8,
1185
+ ):
970
1186
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
971
1187
  beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
972
1188
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -1040,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
1040
1256
 
1041
1257
 
1042
1258
  @decorator_knowngood
1043
- def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
1044
- g: List[Optional[Tensor]]):
1259
+ def _compilable_update_(
1260
+ p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
1261
+ ):
1045
1262
  for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
1046
1263
  u_ = promote(u_.view_as(p_))
1047
1264
  p32_ = promote(p_)
@@ -1051,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
1051
1268
  copy_stochastic_(p_, p32_)
1052
1269
 
1053
1270
 
1054
- def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
1055
- grad: List[Tensor] = None):
1271
+ def update_param_(
1272
+ param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
1273
+ ):
1056
1274
  param, update, grad = list_guard(param, update, grad)
1057
1275
  lr = scalar_guard(lr, param[0])
1058
1276
  if not caution:
@@ -1076,28 +1294,83 @@ def _max_idx(x: List[int]):
1076
1294
 
1077
1295
 
1078
1296
  @decorator_knowngood
1079
- def mean_root(x: torch.Tensor, pow: float):
1080
- return stochastic_round_(x, x.float().pow(pow).mean().pow(-1 / pow / 2))
1297
+ def stable_exp(x: Tensor):
1298
+ # fp16:
1299
+ # exp(x) is stable in [-17, 11]
1300
+ # `stable_exp` extends to [-17, 17]
1301
+ # average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
1302
+ # fp32:
1303
+ # exp(x) is stable in [-103, 88]
1304
+ # `stable_exp` extends to [-103, 103]
1305
+ # average error (in [-87, 87]) reduced from 3.309e-06 to 3.224e-06
1306
+ return torch.where(x > 0, 1 / (-x).exp(), x.exp())
1307
+
1308
+
1309
+ @decorator_knowngood
1310
+ def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
1311
+ # 1 / (mean(x ** pow) ** (1 / pow / 2))
1312
+ log_x = x.double().abs().clamp(min=eps).log()
1313
+ log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
1314
+ return stable_exp(-log_mean_x_pow / pow / 2)
1081
1315
 
1082
1316
 
1083
1317
  @decorator_knowngood
1084
- def divided_root(x, y, pow0, pow1):
1085
- mean_x = x.float().pow(pow0).mean().pow(1 / pow0 / 2)
1086
- mean_y = y.float().pow(pow1).mean().pow(-1 / pow1 / 2)
1087
- return stochastic_round_(x, mean_x * mean_y) # multiply here, as we already divide in pow -1
1318
+ def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
1319
+ # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
1320
+ log_x = x.double().abs().clamp(min=eps).log()
1321
+ log_y = y.double().abs().clamp(min=eps).log()
1322
+
1323
+ x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
1324
+ x_normed = x_normed / pow0 / 2
1325
+
1326
+ y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
1327
+ y_normed = y_normed / pow1 / 2
1328
+
1329
+ return stable_exp(x_normed - y_normed)
1088
1330
 
1089
1331
 
1090
- def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector):
1332
+ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
1333
+ automatic_scale = True
1334
+ manual_hint = " Set it manually using `precond_init_scale=0.1`"
1335
+
1091
1336
  if scale is not None:
1337
+ automatic_scale = False
1092
1338
  warn_once(
1093
- "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics.")
1094
- if scale_scale is not None:
1339
+ "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
1340
+ )
1341
+ if scale_scale is not None and scale_scale != 1:
1095
1342
  warn_once(
1096
- "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale.")
1343
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
1344
+ )
1345
+ elif hessian_vector is None:
1346
+ scale = mean_root(grad, 4) * scale_scale
1347
+ else:
1348
+ scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
1349
+
1350
+ if isinstance(scale, torch.Tensor):
1351
+ scale = scale.item() # slow, but necessary
1352
+
1353
+ if np.isfinite(scale):
1354
+ if scale > scale_max or scale < 1 / scale_max: # fallthrough to later checks
1355
+ warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
1356
+ else:
1357
+ return scale
1358
+
1359
+ if not automatic_scale:
1360
+ raise ValueError("The manually set precond_init_scale is not finite")
1361
+
1362
+ for x in (grad, hessian_vector, vector):
1363
+ if x is None:
1364
+ continue
1365
+ if torch.allclose(x, torch.zeros_like(x)).item():
1366
+ raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
1367
+ if not torch.isfinite(x).all().item():
1368
+ raise ValueError("Grad or HVP is not finite")
1369
+
1370
+ if np.isfinite(scale):
1097
1371
  return scale
1098
- if hessian_vector is None:
1099
- return mean_root(grad, 4) * scale_scale
1100
- return divided_root(vector, hessian_vector, 2, 4) * scale_scale
1372
+
1373
+ raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1101
1374
 
1102
1375
 
1103
1376
  def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
@@ -1108,8 +1381,9 @@ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None)
1108
1381
  return U, V, d
1109
1382
 
1110
1383
 
1111
- def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector,
1112
- dtype=None):
1384
+ def init_Q_exprs(
1385
+ grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1386
+ ):
1113
1387
  """
1114
1388
  For a scalar or tensor `grad`, we initialize its preconditioner Q and
1115
1389
  reusable einsum expressions for updating Q and preconditioning gradient.
@@ -1147,8 +1421,10 @@ def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory
1147
1421
  elif memory_save_mode == "all_diag":
1148
1422
  dim_diag = [True for _ in shape]
1149
1423
  else:
1150
- raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1151
- "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
1424
+ raise ValueError(
1425
+ f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1426
+ "[None, 'one_diag', 'all_diag', 'smart_one_diag']"
1427
+ )
1152
1428
 
1153
1429
  Q = []
1154
1430
  piece1A, piece2A, piece3A = ([], "", "")
@@ -1213,8 +1489,16 @@ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1213
1489
  return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1214
1490
 
1215
1491
 
1216
- def update_lra_precond_(U: List[Tensor], V: List[Tensor], d: List[Tensor], vector: Tensor, hessian_vector: Tensor,
1217
- eps: float, step: float, delayed: bool, ):
1492
+ def update_lra_precond_(
1493
+ U: List[Tensor],
1494
+ V: List[Tensor],
1495
+ d: List[Tensor],
1496
+ vector: Tensor,
1497
+ hessian_vector: Tensor,
1498
+ eps: float,
1499
+ step: float,
1500
+ delayed: bool,
1501
+ ):
1218
1502
  """
1219
1503
  Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1220
1504
  """
@@ -1293,7 +1577,7 @@ def lra_precond(U, V, d, g):
1293
1577
 
1294
1578
 
1295
1579
  @decorator_knowngood
1296
- def dampen_grad(g: Tensor, damp: float = 2 ** -13):
1580
+ def dampen_grad(g: Tensor, damp: float = 2**-13):
1297
1581
  # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
1298
1582
  v = torch.randn_like(g)
1299
1583
  return v, g + damp * g.abs().mean() * v
@@ -1306,7 +1590,7 @@ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor,
1306
1590
  update = update.flatten()
1307
1591
  for p in params:
1308
1592
  size = p.numel()
1309
- copy_stochastic_(p, update[start: start + size].view_as(p))
1593
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1310
1594
  start += size
1311
1595
 
1312
1596
 
@@ -1316,7 +1600,7 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
1316
1600
  update = update.flatten()
1317
1601
  for p in params:
1318
1602
  size = p.numel()
1319
- copy_stochastic_(p, update[start: start + size].view_as(p))
1603
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1320
1604
  start += size
1321
1605
 
1322
1606
 
@@ -1326,7 +1610,7 @@ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1326
1610
  update = update.flatten()
1327
1611
  for p in params:
1328
1612
  size = p.numel()
1329
- stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
1613
+ stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
1330
1614
  start += size
1331
1615
 
1332
1616
 
@@ -1337,16 +1621,19 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1337
1621
  update = update.flatten()
1338
1622
  for p in params:
1339
1623
  size = p.numel()
1340
- outputs.append(update[start: start + size].view_as(p))
1624
+ outputs.append(update[start : start + size].view_as(p))
1341
1625
  start += size
1342
1626
  return outputs
1343
1627
 
1344
1628
 
1629
+ @decorator_knowngood
1345
1630
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1346
- return torch.cat([i.flatten(0, -1 - remaining) for i in x], 0)
1631
+ last_dim = x[0].shape[-remaining:] if remaining else []
1632
+ return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
1347
1633
 
1348
1634
 
1349
- def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
1635
+ @decorator_knowngood
1636
+ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1350
1637
  vs = []
1351
1638
  gs = []
1352
1639
  for g_ in g:
@@ -1356,30 +1643,58 @@ def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
1356
1643
  return flatten(vs), flatten(gs)
1357
1644
 
1358
1645
 
1359
- def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1360
- order = G.dim()
1361
- if V is None:
1362
- V, G = dampen_grad(G)
1363
- conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
1364
- md = min_dtype(Q + [G])
1365
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1366
- Q = [promote(q) for q in Q]
1367
- for i, q in enumerate(Q):
1646
+ def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1647
+ md = min_dtype(args)
1648
+ return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1649
+
1650
+
1651
+ @decorator_knowngood
1652
+ def _psgd_calc_scalars_(Qs: List[Tensor], conjB: Tensor):
1653
+ triangular_qs = []
1654
+ for i, q in enumerate(Qs):
1655
+ q = promote(q)
1368
1656
  if q.dim() <= 1:
1369
- conjB /= q
1657
+ shape = [1] * conjB.ndim
1658
+ shape[i] = -1
1659
+ conjB /= q.view(shape)
1370
1660
  else:
1371
- conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
1372
- conjB)
1373
- if i < order - 1:
1374
- conjB = torch.transpose(conjB, i, order - 1)
1661
+ triangular_qs.append((i, q))
1662
+ return triangular_qs
1663
+
1664
+
1665
+ @decorator_knowngood
1666
+ def _reshape_conjB(solved: Tensor, original_shape: List[int], last_dim: int, new_shape: int):
1667
+ solved = solved.reshape(original_shape)
1668
+ solved.transpose(last_dim, -1)
1669
+ return solved.reshape(new_shape).contiguous()
1670
+
1671
+
1672
+ def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1673
+ order = G.dim()
1674
+ if order > 1:
1675
+ conjB = conjB.view_as(G).permute(*range(1, order), 0)
1676
+ conjB = conjB.to(promote(G.dtype))
1677
+ A = casted_einsum(exprA, *Q, G)
1678
+ solve = torch.compiler.disable(torch.linalg.solve_triangular)
1679
+ original_shape = conjB.shape
1680
+ prev_i = -1
1681
+ for i, tri_q in _psgd_calc_scalars_(Q, conjB):
1682
+ conjB = _reshape_conjB(conjB, original_shape, prev_i, [-1, tri_q.size(0)])
1683
+ prev_i = i
1684
+ conjB = solve(tri_q, conjB, upper=True, left=False)
1685
+ conjB = _reshape_conjB(conjB, original_shape, prev_i, original_shape)
1375
1686
  return A, conjB
1376
1687
 
1377
1688
 
1378
- def psgd_lb(A, max_abs):
1689
+ @decorator_knowngood
1690
+ def _max_select(to_index: Tensor, to_argmax: Tensor):
1691
+ idx = to_argmax.argmax()
1692
+ return to_index.index_select(1, idx).flatten().contiguous()
1693
+
1694
+
1695
+ def psgd_lb(A: Tensor, max_abs: Tensor):
1379
1696
  A /= max_abs
1380
- a0 = torch.einsum("ij,ij->j", A, A)
1381
- i = torch.argmax(a0)
1382
- x = torch.index_select(A, 1, i).flatten().contiguous()
1697
+ x = _max_select(A, torch.einsum("ij,ij->j", A, A))
1383
1698
  x = torch.einsum("i,ij->j", x, A)
1384
1699
  x /= x.norm()
1385
1700
  x = torch.einsum("j,kj->k", x, A)
@@ -1388,28 +1703,52 @@ def psgd_lb(A, max_abs):
1388
1703
  return x
1389
1704
 
1390
1705
 
1706
+ @decorator_knowngood
1707
+ def _subtract_from_line_(state: Tensor, term: Tensor):
1708
+ stochastic_add_([state], [triu_to_line([term])[0][1]], -1)
1709
+
1710
+
1711
+ @decorator_knowngood
1712
+ def _prescale_term_(term1: Tensor, fac: Tensor, norm: Tensor, lower_bound: Tensor):
1713
+ out = term1.float().triu() * fac
1714
+ out = out / torch.where(norm > 0, lower_bound, norm).clamp(tiny_bf16)
1715
+ copy_stochastic_(term1, out)
1716
+
1717
+
1718
+ @decorator_knowngood
1719
+ def _compilable_stochastic_multiply_div_(x: Tensor, fac: Tensor, y: Tensor, z: Tensor):
1720
+ copy_stochastic_(x, promote(x) * promote(fac) * promote(y) / promote(z).clamp(min=tiny_bf16))
1721
+
1722
+
1723
+ @decorator_knowngood
1724
+ def _compilable_add_sub_(x: Tensor, y: Tensor):
1725
+ x = promote(x)
1726
+ y = promote(y)
1727
+ return x - y, x + y
1728
+
1729
+
1391
1730
  @decorator
1392
1731
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1393
1732
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1394
1733
  exprA, exprGs, _ = exprs
1395
1734
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1735
+ precond_lr = scalar_guard(precond_lr, G)
1396
1736
 
1397
1737
  for q, exprG, o in zip(Q, exprGs, oq):
1398
- term1 = promote(torch.einsum(exprG, A, A))
1399
- term2 = promote(torch.einsum(exprG, conjB, conjB))
1400
- term1, term2 = term1 - term2, term1 + term2
1401
- term1 *= precond_lr
1738
+ term1 = torch.einsum(exprG, A, A)
1739
+ term2 = torch.einsum(exprG, conjB, conjB)
1740
+ term1, term2 = _compilable_add_sub_(term1, term2)
1402
1741
  norm = term2.norm(float("inf"))
1403
1742
  if q.dim() < 2:
1404
- term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
1743
+ _compilable_stochastic_multiply_div_(term1, precond_lr, q, norm)
1405
1744
  else:
1406
- torch.triu(term1, out=term1)
1407
- term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
1408
- term1 = torch.mm(term1, q.to(term1.dtype))
1745
+ lower_bound = psgd_lb(term2, norm)
1746
+ _prescale_term_(term1, precond_lr, lower_bound, norm)
1747
+ torch.mm(term1, q.to(term1.dtype), out=term1)
1409
1748
  if store_triu_as_line:
1410
- term1 = triu_to_line([term1])[0][1]
1411
- o = o[1]
1412
- stochastic_add_(o, term1, -1)
1749
+ _subtract_from_line_(q, term1)
1750
+ else:
1751
+ stochastic_add_(o, term1, -1)
1413
1752
 
1414
1753
 
1415
1754
  @decorator_knowngood
@@ -1619,8 +1958,9 @@ def warn_once(msg):
1619
1958
  _warned.add(msg)
1620
1959
 
1621
1960
 
1622
- def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1623
- name: str = "cumulative_prob"):
1961
+ def psgd_should_update(
1962
+ group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
1963
+ ):
1624
1964
  group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
1625
1965
  if not isinstance(prob, float):
1626
1966
  prob = prob(group[f"{name}_prob_step"])
@@ -1632,8 +1972,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
1632
1972
 
1633
1973
 
1634
1974
  @decorator_knowngood
1635
- def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
1636
- cast: bool = True):
1975
+ def precond_grad_cached_(
1976
+ expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
1977
+ ):
1637
1978
  if caution:
1638
1979
  ea = _compilable_cautioning(grad, ea)
1639
1980
  md = min_dtype(list(cached_q) + [ea])
@@ -1753,12 +2094,79 @@ def merge_group(group, *tensors):
1753
2094
 
1754
2095
  out = []
1755
2096
  for t in tensors:
1756
- append_or_extend(out,
1757
- dim_merger(t, group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
1758
- group.get("split", False), ), )
2097
+ append_or_extend(
2098
+ out,
2099
+ dim_merger(
2100
+ t,
2101
+ group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
2102
+ group.get("split", False),
2103
+ ),
2104
+ )
1759
2105
  return out
1760
2106
 
1761
2107
 
2108
+ @decorator_knowngood
2109
+ def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2110
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2111
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2112
+ next_d = d * (g * s).sum()
2113
+ s = s + u * d
2114
+ next_d = next_d / s.abs().sum()
2115
+ next_d = torch.maximum(next_d, d)
2116
+ copy_stochastic_(u_, u * d)
2117
+ copy_stochastic_(d_, next_d)
2118
+ copy_stochastic_(s_, s)
2119
+
2120
+
2121
+ def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2122
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2123
+ _compilable_d_adapt_(grads, update, state, delta)
2124
+
2125
+
2126
+ @decorator_knowngood
2127
+ def _compilable_lr_adapt_(
2128
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2129
+ ):
2130
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2131
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2132
+ lr_grad = d.sigmoid()
2133
+ lr_grad = lr_grad * (1 - lr_grad)
2134
+ lr_grad = lr_grad * (s * g).mean()
2135
+ d = d - lr_grad * lr_lr
2136
+ copy_stochastic_(d_, d)
2137
+ copy_stochastic_(u_, u * d.sigmoid())
2138
+ copy_stochastic_(s_, u)
2139
+
2140
+
2141
+ def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
2142
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2143
+ lr_lr = scalar_guard(lr_lr, grads[0])
2144
+ _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
2145
+
2146
+
2147
+ @decorator_knowngood
2148
+ def _compilable_pointwise_lr_adapt_(
2149
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2150
+ ):
2151
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2152
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2153
+ lr_grad = d.sigmoid()
2154
+ lr_grad = lr_grad * (1 - lr_grad)
2155
+ lr_grad = lr_grad * s * g
2156
+ d = d - lr_grad * lr_lr
2157
+ copy_stochastic_(d_, d)
2158
+ copy_stochastic_(u_, u * d.sigmoid())
2159
+ copy_stochastic_(s_, u)
2160
+
2161
+
2162
+ def pointwise_lr_adaptation(
2163
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
2164
+ ):
2165
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2166
+ lr_lr = scalar_guard(lr_lr, grads[0])
2167
+ _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
2168
+
2169
+
1762
2170
  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
1763
2171
  optimizers = {}
1764
2172
 
@@ -1781,8 +2189,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1781
2189
 
1782
2190
  o = optimizer(parameters, *args, **kwargs)
1783
2191
  step_fn = o.step
1784
- o.step = functools.partial(warn_once,
1785
- msg="You're trying to call `step` on a fused optimizer. This will not do anything.")
2192
+ o.step = functools.partial(
2193
+ warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
2194
+ )
1786
2195
 
1787
2196
  def _step(p: Tensor):
1788
2197
  seen_params.add(p)
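
Note on the utils.py changes above: mean_root and divided_root, used by precond_init_scale, are reworked to compute their means in the log domain (double precision plus logsumexp) and to exponentiate through the new stable_exp, which avoids overflow when gradients are extremely large or small. A quick numerical check of the equivalence (a sketch, not part of the package; the stochastic-rounding wrapper of the old code is dropped):

import math
import torch

def mean_root_reference(x, pow):
    # old formulation: mean(x ** pow) ** (-1 / pow / 2)
    return x.float().pow(pow).mean().pow(-1 / pow / 2)

def mean_root_log(x, pow, eps=1e-12):
    # new formulation from the diff: log-domain mean + stable_exp
    log_x = x.double().abs().clamp(min=eps).log()
    log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
    z = -log_mean_x_pow / pow / 2
    return torch.where(z > 0, 1 / (-z).exp(), z.exp())  # stable_exp

x = torch.randn(1000).abs() + 0.1
print(mean_root_reference(x, 4), mean_root_log(x, 4))          # agree for well-scaled inputs
x_big = x * 1e12
print(mean_root_reference(x_big, 4), mean_root_log(x_big, 4))  # float32 path overflows to 0 via inf; log path stays finite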
heavyball-1.7.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 1.7.0
3
+ Version: 1.7.2
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
heavyball-1.7.2.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
1
+ heavyball/__init__.py,sha256=tt0QMvIbU6IRDexpjSWmWdNEVfYvsPT6-hAWfKrbDQc,20379
2
+ heavyball/chainable.py,sha256=jkiTzaXFjEMJztN3TRGkBV7s0-deCakmR1QGIZHb54o,32635
3
+ heavyball/utils.py,sha256=Y7YkYQhyUEZFUcTPQv6hrAL1gPE9oSydkuIEW5_LxbY,73545
4
+ heavyball-1.7.2.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
5
+ heavyball-1.7.2.dist-info/METADATA,sha256=iUY20QhT8d6hnb1udkOUnQyfRN_r8MM3Vhb0aq5eGNI,43718
6
+ heavyball-1.7.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
7
+ heavyball-1.7.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
+ heavyball-1.7.2.dist-info/RECORD,,
heavyball-1.7.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=64gbqGEWM0zxfXDCOAcB1VtIPd3sdzAOSNHBCzSg8uQ,19762
2
- heavyball/chainable.py,sha256=XCsBgBZtmd4swQSCtMmEpQtpsPbiJc18RAvaW9rlkIs,31174
3
- heavyball/utils.py,sha256=Uj3L-x5a56_G3G_VqqOrU7y098lxjkdjIwkKA7L5ETQ,62759
4
- heavyball-1.7.0.dist-info/licenses/LICENSE,sha256=G9fFZcNIVWjU7o6Pr_4sJBRCNDU5X-zelSxIJ2D48ms,1323
5
- heavyball-1.7.0.dist-info/METADATA,sha256=a8Aar_g95j_wZNL59vYc0BkIHwn49_RjDtflKON-HmQ,43718
6
- heavyball-1.7.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
7
- heavyball-1.7.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.7.0.dist-info/RECORD,,