heavyball 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,11 +1,13 @@
1
+ import contextlib
1
2
  import functools
2
3
  import gc
4
+ import inspect
3
5
  import math
4
6
  import random
7
+ import re
5
8
  import string
6
9
  import warnings
7
10
  from typing import Callable, List, Optional, Tuple, Union
8
- from unittest.mock import patch
9
11
 
10
12
  import numpy as np
11
13
  import torch
@@ -15,13 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
15
17
  from torch.backends import cudnn, opt_einsum
16
18
  from torch.utils._pytree import tree_map
17
19
 
18
config.cache_size_limit = 2**16

compile_mode = "max-autotune-no-cudagraphs"
dynamic = False
compile_mode_recommended_to_none = None
zeroth_power_mode = "qr"  # 'qr' is baseline, 'newtonschulz' converges better and faster
tiny_bf16 = torch.finfo(torch.bfloat16).tiny

# Patterns recognising known double-backward limitations in exception messages, so the optimizer can fall
# back to finite differences instead of failing. They are applied with `Pattern.match`, which anchors at the
# start of the message. PyTorch's aot_autograd error reads
# "torch.compile with aot_autograd does not currently support double backward"
# (NOTE(review): confirm against the installed torch version), so the torch.compile pattern needs a leading
# `.*` -- without it the pattern can never match that message.
_cudnn_double_backward_pattern = re.compile(
    r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
)
_torch_compile_double_backward_pattern = re.compile(r".*compile.*does not currently support double backward")
# Shared prefix of the warnings emitted when falling back to finite-difference Hessian-vector products.
_fd_error = (
    "You can accelerate startup by globally enabling finite_differences first "  #
    "(via opt.finite_differences=True or by subclassing it)\n"
    "Original Error: "
)
25
36
 
26
37
 
27
38
  def decorator(func):
@@ -58,8 +69,17 @@ einsum_base = string.ascii_lowercase
58
69
 
59
70
 
60
71
  @decorator_knowngood
61
- def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
62
- beta1: Tensor, decay: float, grad: List[Tensor], caution, ):
72
+ def _compilable_schedule_free_(
73
+ p: List[Tensor],
74
+ z: List[Tensor],
75
+ ckp1: Tensor,
76
+ update: List[Tensor],
77
+ lr: Tensor,
78
+ beta1: Tensor,
79
+ decay: float,
80
+ grad: List[Tensor],
81
+ caution,
82
+ ):
63
83
  for op, oz, u_, g_ in zip(p, z, update, grad):
64
84
  u_ = u_.view_as(op)
65
85
  p_, z_, u_ = map(promote, (op, oz, u_))
@@ -74,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
74
94
  copy_stochastic_(oz, z_)
75
95
 
76
96
 
77
- def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
78
- z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0, step: int = 0,
79
- decay: float = 0.0, ):
97
+ def schedule_free_(
98
+ lr: float,
99
+ weight_lr_power: float,
100
+ weight_sum: float,
101
+ beta1: float,
102
+ parameters: List[Tensor],
103
+ z: List[Tensor],
104
+ update: List[Tensor],
105
+ grad: List[Tensor],
106
+ caution: bool = False,
107
+ r: float = 0.0,
108
+ step: int = 0,
109
+ decay: float = 0.0,
110
+ ):
80
111
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
81
112
  weight_sum = weight_sum + weight
82
113
 
@@ -149,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
149
180
 
150
181
 
def beta_debias(beta, step):
    """Rescale the EMA coefficient ``beta`` for bias correction at update number ``step``.

    Returns ``1 - (1 - beta) / (1 - beta**step)``. At ``step == 1`` this is 0, so the first update fully
    replaces the (zero-initialised) average; ``step == 0`` would make the denominator zero.
    """
    retained = 1 - beta
    correction = 1 - beta**step
    return 1 - retained / correction
153
184
 
154
185
 
155
186
  def eps_sqrt(item, eps):
@@ -157,8 +188,9 @@ def eps_sqrt(item, eps):
157
188
 
158
189
 
159
190
  @decorator_knowngood
160
- def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
161
- out: List[Optional[Tensor]]):
191
+ def _compilable_exp_avg_sq_(
192
+ state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
193
+ ):
162
194
  g32 = promote(grad)
163
195
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
164
196
 
@@ -219,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
219
251
  copy_stochastic_list_(gradients, g32)
220
252
 
221
253
 
222
- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
223
- minimum: float = 1e-3, eps: float = 1e-8):
254
+ def adaptive_gradient_clipping_(
255
+ parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
256
+ ):
224
257
  if clip_val <= 0:
225
258
  return gradients
226
259
  parameters, gradients = list_guard(parameters, gradients)
@@ -259,9 +292,11 @@ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
259
292
 
260
293
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
261
294
  _ignore_warning(
262
- "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak")
295
+ "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
296
+ )
263
297
  _ignore_warning(
264
- "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak")
298
+ "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
+ )
265
300
 
266
301
 
267
302
  @decorator
@@ -408,7 +443,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
408
443
 
409
444
  assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
410
445
  in_str = einsum_base[: exp_avg.dim()]
411
- out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]
446
+ out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
412
447
 
413
448
  from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
414
449
  if not from_shampoo:
@@ -418,8 +453,9 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
418
453
  out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
419
454
 
420
455
  subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
421
- exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
422
- *[q for q in new_qs if q is not None])
456
+ exp_avg_new = torch.einsum(
457
+ subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
+ )
423
459
  copy_stochastic_(exp_avg, exp_avg_new)
424
460
 
425
461
  for q, q_new in zip(Q, new_qs):
@@ -546,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
546
582
  _compilable_stochastic_add_(x, y, alpha)
547
583
 
548
584
 
585
@decorator_knowngood
def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
    """In place, for every pair ``(x[i], y[i])``: ``x[i] <- (x[i] + y[i] * alpha) / divisor``.

    The arithmetic is done on ``promote``-d copies and written back with ``copy_stochastic_``
    (NOTE(review): presumably an fp32 upcast followed by stochastic rounding to the storage dtype).
    """
    for target, addend in zip(x, y):
        total = promote(target) + promote(addend) * alpha
        copy_stochastic_(target, total / divisor)
591
+
592
+
593
def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
    """Public entry point for ``x <- (x + y * alpha) / divisor`` (in place).

    ``x`` and ``y`` are normalised with ``list_guard``; ``alpha`` and ``divisor`` with ``scalar_guard`` using
    ``x[0]`` as the reference tensor. The work is done by ``_compilable_stochastic_add_divide_``.
    """
    xs, ys = list_guard(x, y)
    alpha_t, divisor_t = scalar_guard(alpha, divisor, xs[0])
    _compilable_stochastic_add_divide_(xs, ys, alpha_t, divisor_t)
597
+
598
+
549
599
  @decorator_knowngood
550
600
  def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
551
601
  for x_, y_ in zip(x, y):
@@ -594,6 +644,20 @@ def promote(x):
594
644
  return x
595
645
 
596
646
 
647
def promote_detach(x, should_promote):
    """Return ``x`` detached from the autograd graph, passing it through ``promote`` first if requested.

    ``None`` is returned unchanged (and ``promote`` is not called for it).
    """
    if x is None:
        return None
    source = promote(x) if should_promote else x
    return source.detach()
653
+
654
+
655
def detach(x):
    """Detach ``x`` from the autograd graph if it is a ``Tensor``; any other value (e.g. ``None``) is returned as is."""
    return x.detach() if isinstance(x, Tensor) else x
659
+
660
+
597
661
  def min_dtype(xs: List[Tensor]):
598
662
  dtypes = [x.dtype for x in xs]
599
663
  for d in (torch.float32, torch.bfloat16, torch.float16):
@@ -647,25 +711,36 @@ def project(grad, Q, back: bool):
647
711
  return grad
648
712
 
649
713
 
650
@contextlib.contextmanager
def patch_backward():
    """Force ``create_graph=True`` on every ``backward`` call made inside the ``with`` block.

    Both ``torch.Tensor.backward`` and ``torch.autograd.backward`` are replaced for the duration of the
    context, so gradients produced by a user closure stay attached to the graph and can be differentiated
    a second time. The original functions are restored on exit, including when the body raises.
    """

    @contextlib.contextmanager
    def _inner(module):
        # Swap ``module.backward`` for a wrapper that forces create_graph=True, and always restore it.
        original = module.backward
        # functools.wraps sets __wrapped__, so a nested patch_backward still sees the real signature here.
        signature = inspect.signature(original)

        @functools.wraps(original)
        def patched_backward(*args, **kwargs):
            # Bind positional AND keyword arguments so callers passing required arguments by name
            # (e.g. ``torch.autograd.backward(tensors=...)``) work, then override create_graph.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            call_kwargs = dict(bound.arguments)
            call_kwargs["create_graph"] = True
            return original(**call_kwargs)

        module.backward = patched_backward
        try:
            yield
        finally:
            # Without this, an exception inside the block would leave backward patched for good.
            module.backward = original

    with _inner(torch.Tensor), _inner(torch.autograd):
        yield
737
+
738
def hasattr_none(obj, name):
    """Return ``True`` if ``obj`` has an attribute ``name`` whose value is not ``None``.

    A missing attribute and an attribute set to ``None`` both give ``False``.
    """
    value = getattr(obj, name, None)
    return value is not None
740
+
741
+
742
class ExactHVPFailed(ValueError):
    """Signals that the exact (double-backward) Hessian-vector product could not be computed.

    Raised by ``StatefulOptimizer._double_backward_hvp`` and caught by ``StatefulOptimizer._handle_closure``,
    which may then fall back to finite differences.
    """
669
744
 
670
745
 
671
746
  class StatefulOptimizer(torch.optim.Optimizer):
@@ -682,6 +757,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
682
757
  precond_schedule: Union[Callable, float, None] = None
683
758
  stochastic_schedule: bool = False
684
759
  finite_differences: bool = False
760
+ fallback_to_finite_differences: bool = True
761
+ _fallback_enabled: bool = False
762
+ hvp_interval: int = 1 # grad is faster initially, hvp later
685
763
 
686
764
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
687
765
  super().__init__(params, {**defaults, "foreach": foreach})
@@ -708,29 +786,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
708
786
  old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
709
787
  mars_correction(g_list, old_gs, mars_gamma, beta)
710
788
 
711
def split_p_and_g_in_group(
    self,
    group: dict,
    skip_none: bool = True,
    should_promote: bool = True,
    beta1: float = -1.0,
    raw: bool = False,
):
    """Iterate over ``group["params"]`` and yield ``(tensor, grad)`` pairs, consuming ``p.grad``.

    Each visited parameter has its ``.grad`` reset to ``None``; parameters whose ``.grad`` is ``None`` are
    skipped when ``skip_none`` is set.

    With ``raw=True`` the parameter itself is yielded together with its untouched gradient.

    Otherwise the parameter is split into the views cached in ``self.mapping`` (built with ``merge_group``
    on first use). Its gradient, ``p.vector`` and ``p.hessian_vector`` are split the same way, and the
    parameter's own ``vector``/``hessian_vector`` are cleared. Each view is then yielded with its gradient
    slice, after ``promote_detach`` (honouring ``should_promote``). When ``beta1 >= 0`` and the group
    enables ``mars``, the slice also goes through ``mars_correct_list``. The view's ``vector`` and
    ``hessian_vector`` attributes receive the matching detached slices.
    """
    for param in group["params"]:
        param_grad = getattr(param, "grad", None)
        if skip_none and param_grad is None:
            continue

        param.grad = None  # the gradient is handed to the caller; clear it on the parameter

        if raw:
            yield param, param_grad
            continue

        if param in self.mapping:
            views = self.mapping[param]
        else:
            views = merge_group(group, param)
            self.mapping[param] = views

        param_vector = getattr(param, "vector", None)
        param_hvp = getattr(param, "hessian_vector", None)
        param.vector = None
        param.hessian_vector = None

        def _split_like_views(tensor):
            # A missing tensor becomes one ``None`` per view so the zip below stays aligned.
            if tensor is None:
                return [None] * len(views)
            return merge_group(group, tensor)

        grad_parts = _split_like_views(param_grad)
        vector_parts = _split_like_views(param_vector)
        hvp_parts = _split_like_views(param_hvp)

        for view, g, v, hv in zip(views, grad_parts, vector_parts, hvp_parts):
            g = promote_detach(g, should_promote)
            if beta1 >= 0 and group.get("mars", False):
                self.mars_correct_list(group, [view], [g], group["mars_gamma"], beta1)
            view.vector = promote_detach(v, should_promote)
            view.hessian_vector = promote_detach(hv, should_promote)
            yield view, g
735
830
 
736
831
  def state_size(self) -> int:
@@ -794,6 +889,66 @@ class StatefulOptimizer(torch.optim.Optimizer):
794
889
  set_(self.state_(p)["param_ema"], p.data)
795
890
  set_(p.data, ema_clone)
796
891
 
892
def _finite_differences_hvp(self, closure):
    """Approximate Hessian-vector products with one extra closure call (finite differences).

    1. Evaluate ``closure`` and keep every parameter's gradient ``g(x)``.
    2. Draw ``v = randn_like(p)`` (stored as ``p.vector``). Perturb the parameter in place by
       ``sqrt(finfo(p.dtype).eps) * v`` via ``stochastic_add_`` (assumed to compute ``x += alpha * y``).
    3. Evaluate ``closure`` again at the perturbed point.
    4. Restore the saved parameters and put ``g(x)`` back into ``p.grad``. Store the gradient difference
       (perturbed minus original) as ``p.hessian_vector``; it is not divided by the step size.

    Gradients of the two passes are paired by parameter identity rather than by position. A parameter
    without a gradient in the perturbed pass is treated as having a zero gradient there. A parameter that
    gets a gradient only in the perturbed pass keeps ``grad=None``. The parameters are restored even if
    the second closure call raises.

    Returns:
        The loss returned by the first closure call.
    """
    with torch.enable_grad():
        loss = closure()  # closure without retain_graph=True

    perturbed = []  # (parameter, gradient at the unperturbed point), in param_groups order
    for group in self.param_groups:
        for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
            perturbed.append((p, g))
            p.vector = torch.randn_like(p)
            p.orig = p.data.clone()
            # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
            stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)

    try:
        with torch.enable_grad():
            closure()
    finally:
        # we restore the saved copy instead of subtracting the vector again to avoid accumulating error
        # from (x + eps - eps + eps - eps); this costs more memory, but the imprecision seems too severe
        for p, _ in perturbed:
            p.data.copy_(p.orig)
            del p.orig

    perturbed_grads = {}  # tensors hash by identity, so this pairs gradients per parameter object
    for group in self.param_groups:
        for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
            perturbed_grads[p] = g

    for p, g_orig in perturbed:
        g_new = perturbed_grads.get(p)
        if g_new is None:
            g_new = torch.zeros_like(g_orig)
        p.grad = g_orig
        stochastic_add_(g_new, g_orig, -1)  # technically, we have to divide by the scale here
        p.hessian_vector = g_new
    return loss
918
+
919
def _double_backward_hvp(self, closure):
    """Compute exact Hessian-vector products by differentiating the gradients a second time.

    ``closure`` runs under ``patch_backward`` so its ``backward`` calls keep the gradients attached to the
    graph. For every parameter with a gradient, ``v = randn_like(p)`` is drawn and
    ``torch.autograd.grad(grads, params, vs)`` yields the products. On success each such parameter
    receives detached ``grad``, ``vector`` and ``hessian_vector`` attributes.

    Returns:
        The loss returned by ``closure``.

    Raises:
        ValueError: if no parameter received a gradient.
        ExactHVPFailed: if ``torch.autograd.grad`` raises a ``RuntimeError`` (chained as ``__cause__``), or
            if some parameter has no second-order derivative. Nothing is written back to the parameters
            in either case, so ``p.grad`` stays ``None`` (it was cleared when the gradients were
            collected). A fallback that re-runs the closure therefore does not accumulate onto stale
            gradients.
    """
    with torch.enable_grad(), patch_backward():
        loss = closure()

    params, grads = [], []
    for group in self.param_groups:
        for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
            params.append(p)
            grads.append(g)

    if not params:
        raise ValueError("No parameter has gradients")

    vs = [torch.randn_like(p) for p in params]
    with torch.enable_grad():
        try:
            hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
        except RuntimeError as e:
            raise ExactHVPFailed(str(e)) from e

    # Validate before touching any parameter so a failure leaves no half-written state behind.
    unused = [list(p.shape) for p, hv in zip(params, hvs) if hv is None]
    if unused:
        raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")

    for p, g, v, hv in zip(params, grads, vs, hvs):
        p.hessian_vector = detach(hv)
        p.grad = detach(g)
        p.vector = detach(v)

    return loss
951
+
797
952
  def _handle_closure(self, closure):
798
953
  hessian_approx = self.hessian_approx and self._is_preconditioning
799
954
 
@@ -802,56 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
802
957
  raise ValueError("Hessian approximation requires a closure.")
803
958
  return None
804
959
 
805
- if not hessian_approx:
960
+ step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
+ if not hessian_approx or step % self.hvp_interval == 0:
806
962
  with torch.enable_grad():
807
963
  loss = closure()
808
964
  return loss
809
965
 
810
- if self.finite_differences:
811
- with torch.enable_grad():
812
- loss = closure() # closure without retain_graph=True
813
-
814
- grads = []
815
- for group in self.param_groups:
816
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
817
- grads.append(g)
818
- p.vector = torch.randn_like(p)
819
- p.orig = p.data.clone()
820
- # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
821
- stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
822
- else:
823
- with torch.enable_grad():
824
- loss = modify_closure(closure)
825
-
826
- if self.finite_differences:
827
- with torch.enable_grad():
828
- closure()
829
-
830
- # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
831
- # this costs more memory, but the imprecision seems too severe to use the other method
832
- for group in self.param_groups:
833
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
834
- p.grad = grads.pop(0)
835
- stochastic_add_(g, p.grad, -1)
836
- p.hessian_vector = g
837
- p.data.copy_(p.orig)
838
- del p.orig
839
- else:
840
- for group in self.param_groups:
841
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
842
- p.grad = g
843
- params, grads = zip(*[x for group in self.param_groups for x in
844
- self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
845
- vs = [torch.randn_like(p) for p in params]
846
- with torch.enable_grad():
847
- hvs = torch.autograd.grad(grads, params, vs)
848
-
849
- for p, g, v, hv in zip(params, grads, vs, hvs):
850
- p.hessian_vector = hv
851
- p.grad = g
852
- p.vector = v
966
+ if self.finite_differences or self._fallback_enabled:
967
+ return self._finite_differences_hvp(closure)
853
968
 
854
- return loss
969
+ try:
970
+ return self._double_backward_hvp(closure)
971
+ except NotImplementedError as e:
972
+ if not self.fallback_to_finite_differences:
973
+ raise
974
+ if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
975
+ raise
976
+ warn_once(
977
+ "CUDNN doesn't support double-backward for some models (including RNNs). " #
978
+ f"Falling back to finite_differences.\n{_fd_error}{e}"
979
+ )
980
+ except RuntimeError as e:
981
+ if not self.fallback_to_finite_differences:
982
+ raise
983
+ if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
984
+ raise
985
+ warn_once(
986
+ f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
987
+ f"the model.\n{_fd_error}{e}"
988
+ )
989
+ except ExactHVPFailed as e:
990
+ if not self.fallback_to_finite_differences:
991
+ raise
992
+ warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
993
+ self._fallback_enabled = True
994
+ return self._handle_closure(closure)
855
995
 
856
996
  def step(self, closure: Optional[Callable] = None):
857
997
  if self.precond_schedule is None:
@@ -867,7 +1007,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
867
1007
  self._step(group)
868
1008
  if self.use_ema:
869
1009
  self.ema_update()
870
-
1010
+ for real, views in self.mapping.items():
1011
+ for tensor in (real, *views):
1012
+ for key in ("grad", "vector", "hessian_vector", "orig"):
1013
+ if hasattr(tensor, key):
1014
+ setattr(tensor, key, None)
871
1015
  return loss
872
1016
 
873
1017
 
@@ -887,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
887
1031
 
888
1032
 
889
1033
  @decorator_knowngood
890
- def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
891
- step: Tensor, eps: Tensor, ):
1034
+ def _compilable_adam_(
1035
+ exp_avg: List[Tensor],
1036
+ exp_avg_sq: List[Tensor],
1037
+ grad: List[Tensor],
1038
+ beta1: Tensor,
1039
+ beta2: Tensor,
1040
+ step: Tensor,
1041
+ eps: Tensor,
1042
+ ):
892
1043
  beta1 = beta_debias(beta1, step)
893
1044
  beta2 = beta_debias(beta2, step)
894
1045
 
@@ -899,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
899
1050
  copy_stochastic_list_(grad, u32)
900
1051
 
901
1052
 
902
- def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
903
- eps: float = 1e-8, ):
1053
+ def adam_(
1054
+ exp_avg: List[Tensor],
1055
+ exp_avg_sq: List[Tensor],
1056
+ grad: List[Tensor],
1057
+ beta1: float,
1058
+ beta2: float,
1059
+ step: int,
1060
+ eps: float = 1e-8,
1061
+ ):
904
1062
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
905
1063
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
906
1064
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -908,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
908
1066
 
909
1067
 
910
1068
  @decorator_knowngood
911
- def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
912
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
913
- caution: bool, ):
1069
+ def _fused_compilable_adam_(
1070
+ y: List[Tensor],
1071
+ exp_avg: List[Tensor],
1072
+ exp_avg_sq: List[Tensor],
1073
+ update: List[Tensor],
1074
+ grad: List[Tensor],
1075
+ beta1: Tensor,
1076
+ beta2: Tensor,
1077
+ step: Tensor,
1078
+ decay: Tensor,
1079
+ lr: Tensor,
1080
+ eps: Tensor,
1081
+ caution: bool,
1082
+ ):
914
1083
  beta1 = beta_debias(beta1, step)
915
1084
  beta2 = beta_debias(beta2, step)
916
1085
 
@@ -921,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
921
1090
  _compilable_update_(y, u32, decay, lr, caution, g32)
922
1091
 
923
1092
 
924
def fused_adam_(
    y: List[Tensor],
    exp_avg: List[Tensor],
    exp_avg_sq: List[Tensor],
    update: List[Tensor],
    grad: List[Tensor],
    beta1: float,
    beta2: float,
    step: int,
    lr: float,
    eps: float,
    decay: float,
    caution: bool,
):
    """Fused Adam step: normalise the inputs, then dispatch to ``_fused_compilable_adam_``.

    ``y``, ``exp_avg``, ``exp_avg_sq`` and ``grad`` go through ``list_guard``. ``beta1``, ``beta2``, ``step``
    and ``lr`` go through ``scalar_guard`` with ``y[0]`` as the reference tensor. ``update``, ``eps``,
    ``decay`` and ``caution`` are forwarded unchanged. The compiled kernel takes ``decay`` before ``lr``,
    unlike this wrapper's signature.
    """
    params, m, v, g = list_guard(y, exp_avg, exp_avg_sq, grad)
    b1, b2, step_t, lr_t = scalar_guard(beta1, beta2, step, lr, params[0])
    _fused_compilable_adam_(params, m, v, update, g, b1, b2, step_t, decay, lr_t, eps, caution)
930
1110
 
931
1111
 
932
1112
  @decorator_knowngood
933
- def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
934
- beta2: Tensor, step: Tensor, eps: Tensor, ):
1113
+ def _compilable_laprop_(
1114
+ exp_avg: List[Tensor],
1115
+ exp_avg_sq: List[Tensor],
1116
+ grad: List[Tensor],
1117
+ beta1: Tensor,
1118
+ beta2: Tensor,
1119
+ step: Tensor,
1120
+ eps: Tensor,
1121
+ ):
935
1122
  beta1 = beta_debias(beta1, step)
936
1123
  beta2 = beta_debias(beta2, step)
937
1124
 
@@ -942,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
942
1129
  copy_stochastic_list_(grad, gp32)
943
1130
 
944
1131
 
945
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
946
- eps: float = 1e-8, ):
1132
+ def laprop_(
1133
+ exp_avg: List[Tensor],
1134
+ exp_avg_sq: List[Tensor],
1135
+ grad: List[Tensor],
1136
+ beta1: float,
1137
+ beta2: float,
1138
+ step: int,
1139
+ eps: float = 1e-8,
1140
+ ):
947
1141
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
948
1142
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
949
1143
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -951,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
951
1145
 
952
1146
 
953
1147
  @decorator_knowngood
954
- def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
955
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor, caution: bool,
956
- eps: Tensor, ):
1148
+ def _fused_compilable_laprop_(
1149
+ y: List[Tensor],
1150
+ exp_avg: List[Tensor],
1151
+ exp_avg_sq: List[Tensor],
1152
+ update: List[Tensor],
1153
+ grad: List[Tensor],
1154
+ beta1: Tensor,
1155
+ beta2: Tensor,
1156
+ step: Tensor,
1157
+ lr: Tensor,
1158
+ decay: Tensor,
1159
+ caution: bool,
1160
+ eps: Tensor,
1161
+ ):
957
1162
  beta1 = beta_debias(beta1, step)
958
1163
  beta2 = beta_debias(beta2, step)
959
1164
 
@@ -964,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
964
1169
  _compilable_update_(y, u32, decay, lr, caution, gp32)
965
1170
 
966
1171
 
967
- def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
968
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
969
- eps: float = 1e-8, ):
1172
+ def fused_laprop_(
1173
+ y: List[Tensor],
1174
+ exp_avg: List[Tensor],
1175
+ exp_avg_sq: List[Tensor],
1176
+ update: List[Tensor],
1177
+ grad: List[Tensor],
1178
+ beta1: float,
1179
+ beta2: float,
1180
+ step: int,
1181
+ lr: float,
1182
+ decay: float,
1183
+ caution: bool,
1184
+ eps: float = 1e-8,
1185
+ ):
970
1186
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
971
1187
  beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
972
1188
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -1040,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
1040
1256
 
1041
1257
 
1042
1258
  @decorator_knowngood
1043
- def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
1044
- g: List[Optional[Tensor]]):
1259
+ def _compilable_update_(
1260
+ p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
1261
+ ):
1045
1262
  for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
1046
1263
  u_ = promote(u_.view_as(p_))
1047
1264
  p32_ = promote(p_)
@@ -1051,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
1051
1268
  copy_stochastic_(p_, p32_)
1052
1269
 
1053
1270
 
1054
- def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
1055
- grad: List[Tensor] = None):
1271
+ def update_param_(
1272
+ param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
1273
+ ):
1056
1274
  param, update, grad = list_guard(param, update, grad)
1057
1275
  lr = scalar_guard(lr, param[0])
1058
1276
  if not caution:
@@ -1076,28 +1294,74 @@ def _max_idx(x: List[int]):
1076
1294
 
1077
1295
 
@decorator_knowngood
def stable_exp(x: Tensor):
    """Element-wise ``exp(x)`` that evaluates positive inputs as ``1 / exp(-x)``.

    Measured behaviour:
    fp16: exp(x) is stable in [-17, 11]; `stable_exp` extends to [-17, 17];
          average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
    fp32: exp(x) is stable in [-103, 88]; `stable_exp` extends to [-103, 103];
          average error (in [-87, 87]) reduced from 3.309e-06 to 3.224e-06
    """
    positive = x > 0
    via_reciprocal = 1 / (-x).exp()
    direct = x.exp()
    return torch.where(positive, via_reciprocal, direct)
1081
1307
 
1082
1308
 
@decorator_knowngood
def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
    """Return ``1 / mean(|x| ** pow) ** (1 / pow / 2)``, computed in log space in float64.

    The mean runs over *all* elements of ``x``, which is flattened first, so the result is a 0-dim tensor
    for inputs of any rank. Magnitudes are clamped to ``eps`` before taking the log.
    """
    # flatten: logsumexp(dim=0) alone only reduces the first axis, which for ndim >= 2 gives a per-column
    # value instead of the global mean (and a non-scalar result).
    log_x = x.double().abs().clamp(min=eps).log().flatten()
    # log(mean(|x| ** pow)) = logsumexp(pow * log|x|) - log(numel)
    log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
    return stable_exp(-log_mean_x_pow / pow / 2)
1315
+
1316
+
1317
@decorator_knowngood
def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
    """Return ``mean(|x| ** pow0) ** (1 / pow0 / 2) / mean(|y| ** pow1) ** (1 / pow1 / 2)``.

    Computed in log space in float64 with magnitudes clamped to ``eps``. Both means run over all elements
    (inputs are flattened first), so the result is a 0-dim tensor for inputs of any rank.
    """
    # flatten: logsumexp(dim=0) alone only reduces the first axis for ndim >= 2 inputs.
    log_x = x.double().abs().clamp(min=eps).log().flatten()
    log_y = y.double().abs().clamp(min=eps).log().flatten()

    x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
    x_normed = x_normed / pow0 / 2

    y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
    y_normed = y_normed / pow1 / 2

    return stable_exp(x_normed - y_normed)
1089
1330
 
1090
def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
    """Return the initial preconditioner scale as a Python number.

    - ``scale`` given: used as is. ``scale_scale`` is *not* applied; a warning says to fold it in manually.
    - otherwise, without a Hessian-vector product: ``mean_root(grad, 4) * scale_scale``.
    - otherwise: ``divided_root(vector, hessian_vector, 2, 4) * scale_scale``.

    For the automatic heuristics, ``scale_scale=None`` is treated as 1. A finite result outside
    ``[1 / scale_max, scale_max]`` only triggers a warning.

    Raises:
        ValueError: if the resulting scale is not finite. When the scale was computed automatically, the
            message points at an all-zero or non-finite ``grad`` / ``hessian_vector`` / ``vector``.
    """
    automatic_scale = True
    manual_hint = " Set it manually using `precond_init_scale=0.1`"
    if scale is not None:
        automatic_scale = False
        warn_once(
            "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
        )
        if scale_scale is not None and scale_scale != 1:
            warn_once(
                "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
            )
    else:
        if scale_scale is None:
            scale_scale = 1  # otherwise `tensor * None` raises a TypeError below
        if hessian_vector is None:
            scale = mean_root(grad, 4) * scale_scale
        else:
            scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
    if isinstance(scale, torch.Tensor):
        scale = scale.item()  # slow, but necessary
    if np.isfinite(scale):
        if scale > scale_max or scale < 1 / scale_max:
            warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
        return scale
    if not automatic_scale:
        raise ValueError("The manually set precond_init_scale is not finite")

    # Diagnose why the automatic heuristic produced a non-finite value.
    for x in (grad, hessian_vector, vector):
        if x is None:
            continue
        # torch.allclose returns a Python bool; calling .item() on it raised AttributeError.
        if torch.allclose(x, torch.zeros_like(x)):
            raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
        if not torch.isfinite(x).all().item():
            raise ValueError("Grad or HVP is not finite")
    raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1101
1365
 
1102
1366
 
1103
1367
  def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
@@ -1108,8 +1372,9 @@ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None)
1108
1372
  return U, V, d
1109
1373
 
1110
1374
 
1111
- def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector,
1112
- dtype=None):
1375
+ def init_Q_exprs(
1376
+ grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1377
+ ):
1113
1378
  """
1114
1379
  For a scalar or tensor `grad`, we initialize its preconditioner Q and
1115
1380
  reusable einsum expressions for updating Q and preconditioning gradient.
@@ -1147,8 +1412,10 @@ def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory
1147
1412
  elif memory_save_mode == "all_diag":
1148
1413
  dim_diag = [True for _ in shape]
1149
1414
  else:
1150
- raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1151
- "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
1415
+ raise ValueError(
1416
+ f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1417
+ "[None, 'one_diag', 'all_diag', 'smart_one_diag']"
1418
+ )
1152
1419
 
1153
1420
  Q = []
1154
1421
  piece1A, piece2A, piece3A = ([], "", "")
@@ -1213,8 +1480,16 @@ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1213
1480
  return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1214
1481
 
1215
1482
 
1216
- def update_lra_precond_(U: List[Tensor], V: List[Tensor], d: List[Tensor], vector: Tensor, hessian_vector: Tensor,
1217
- eps: float, step: float, delayed: bool, ):
1483
+ def update_lra_precond_(
1484
+ U: List[Tensor],
1485
+ V: List[Tensor],
1486
+ d: List[Tensor],
1487
+ vector: Tensor,
1488
+ hessian_vector: Tensor,
1489
+ eps: float,
1490
+ step: float,
1491
+ delayed: bool,
1492
+ ):
1218
1493
  """
1219
1494
  Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1220
1495
  """
@@ -1293,7 +1568,7 @@ def lra_precond(U, V, d, g):
1293
1568
 
1294
1569
 
1295
1570
  @decorator_knowngood
1296
- def dampen_grad(g: Tensor, damp: float = 2 ** -13):
1571
+ def dampen_grad(g: Tensor, damp: float = 2**-13):
1297
1572
  # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
1298
1573
  v = torch.randn_like(g)
1299
1574
  return v, g + damp * g.abs().mean() * v
@@ -1306,7 +1581,7 @@ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor,
1306
1581
  update = update.flatten()
1307
1582
  for p in params:
1308
1583
  size = p.numel()
1309
- copy_stochastic_(p, update[start: start + size].view_as(p))
1584
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1310
1585
  start += size
1311
1586
 
1312
1587
 
@@ -1316,7 +1591,7 @@ def apply_flat_update(params: List[Tensor], update: Tensor):
1316
1591
  update = update.flatten()
1317
1592
  for p in params:
1318
1593
  size = p.numel()
1319
- copy_stochastic_(p, update[start: start + size].view_as(p))
1594
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1320
1595
  start += size
1321
1596
 
1322
1597
 
@@ -1326,7 +1601,7 @@ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1326
1601
  update = update.flatten()
1327
1602
  for p in params:
1328
1603
  size = p.numel()
1329
- stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
1604
+ stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
1330
1605
  start += size
1331
1606
 
1332
1607
 
@@ -1337,16 +1612,19 @@ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1337
1612
  update = update.flatten()
1338
1613
  for p in params:
1339
1614
  size = p.numel()
1340
- outputs.append(update[start: start + size].view_as(p))
1615
+ outputs.append(update[start : start + size].view_as(p))
1341
1616
  start += size
1342
1617
  return outputs
1343
1618
 
1344
1619
 
1620
+ @decorator_knowngood
1345
1621
  def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1346
- return torch.cat([i.flatten(0, -1 - remaining) for i in x], 0)
1622
+ last_dim = x[0].shape[-remaining:] if remaining else []
1623
+ return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
1347
1624
 
1348
1625
 
1349
- def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
1626
+ @decorator_knowngood
1627
+ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1350
1628
  vs = []
1351
1629
  gs = []
1352
1630
  for g_ in g:
@@ -1356,22 +1634,27 @@ def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
1356
1634
  return flatten(vs), flatten(gs)
1357
1635
 
1358
1636
 
1359
- def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1637
+ @decorator_knowngood
1638
+ def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1639
+ md = min_dtype(args)
1640
+ return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1641
+
1642
+
1643
+ def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1360
1644
  order = G.dim()
1361
- if V is None:
1362
- V, G = dampen_grad(G)
1363
- conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
1364
- md = min_dtype(Q + [G])
1365
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1366
- Q = [promote(q) for q in Q]
1645
+ if order > 1:
1646
+ conjB = conjB.view_as(G).permute(*range(1, order), 0)
1647
+ conjB = conjB.to(promote(G.dtype))
1648
+ A = casted_einsum(exprA, *Q, G)
1367
1649
  for i, q in enumerate(Q):
1650
+ q = promote(q)
1368
1651
  if q.dim() <= 1:
1369
1652
  conjB /= q
1370
1653
  else:
1371
- conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
1372
- conjB)
1654
+ solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
1655
+ conjB = solved.reshape_as(conjB)
1373
1656
  if i < order - 1:
1374
- conjB = torch.transpose(conjB, i, order - 1)
1657
+ conjB = conjB.transpose(i, -1)
1375
1658
  return A, conjB
1376
1659
 
1377
1660
 
@@ -1407,9 +1690,12 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1407
1690
  term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
1408
1691
  term1 = torch.mm(term1, q.to(term1.dtype))
1409
1692
  if store_triu_as_line:
1410
- term1 = triu_to_line([term1])[0][1]
1411
- o = o[1]
1412
- stochastic_add_(o, term1, -1)
1693
+ term1 = triu_to_line([term1])[0][1] # Convert update to line format
1694
+ # Apply update directly to the tensor part of the state tuple o[1]
1695
+ stochastic_add_(o[1], term1, -1)
1696
+ else:
1697
+ # Apply update to the state tensor o
1698
+ stochastic_add_(o, term1, -1)
1413
1699
 
1414
1700
 
1415
1701
  @decorator_knowngood
@@ -1619,8 +1905,9 @@ def warn_once(msg):
1619
1905
  _warned.add(msg)
1620
1906
 
1621
1907
 
1622
- def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1623
- name: str = "cumulative_prob"):
1908
+ def psgd_should_update(
1909
+ group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
1910
+ ):
1624
1911
  group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
1625
1912
  if not isinstance(prob, float):
1626
1913
  prob = prob(group[f"{name}_prob_step"])
@@ -1632,8 +1919,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
1632
1919
 
1633
1920
 
1634
1921
  @decorator_knowngood
1635
- def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
1636
- cast: bool = True):
1922
+ def precond_grad_cached_(
1923
+ expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
1924
+ ):
1637
1925
  if caution:
1638
1926
  ea = _compilable_cautioning(grad, ea)
1639
1927
  md = min_dtype(list(cached_q) + [ea])
@@ -1753,12 +2041,79 @@ def merge_group(group, *tensors):
1753
2041
 
1754
2042
  out = []
1755
2043
  for t in tensors:
1756
- append_or_extend(out,
1757
- dim_merger(t, group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
1758
- group.get("split", False), ), )
2044
+ append_or_extend(
2045
+ out,
2046
+ dim_merger(
2047
+ t,
2048
+ group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
2049
+ group.get("split", False),
2050
+ ),
2051
+ )
1759
2052
  return out
1760
2053
 
1761
2054
 
2055
+ @decorator_knowngood
2056
+ def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2057
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2058
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2059
+ next_d = d * (g * s).sum()
2060
+ s = s + u * d
2061
+ next_d = next_d / s.abs().sum()
2062
+ next_d = torch.maximum(next_d, d)
2063
+ copy_stochastic_(u_, u * d)
2064
+ copy_stochastic_(d_, next_d)
2065
+ copy_stochastic_(s_, s)
2066
+
2067
+
2068
+ def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2069
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2070
+ _compilable_d_adapt_(grads, update, state, delta)
2071
+
2072
+
2073
+ @decorator_knowngood
2074
+ def _compilable_lr_adapt_(
2075
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2076
+ ):
2077
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2078
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2079
+ lr_grad = d.sigmoid()
2080
+ lr_grad = lr_grad * (1 - lr_grad)
2081
+ lr_grad = lr_grad * (s * g).mean()
2082
+ d = d - lr_grad * lr_lr
2083
+ copy_stochastic_(d_, d)
2084
+ copy_stochastic_(u_, u * d.sigmoid())
2085
+ copy_stochastic_(s_, u)
2086
+
2087
+
2088
+ def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
2089
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2090
+ lr_lr = scalar_guard(lr_lr, grads[0])
2091
+ _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
2092
+
2093
+
2094
+ @decorator_knowngood
2095
+ def _compilable_pointwise_lr_adapt_(
2096
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2097
+ ):
2098
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2099
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2100
+ lr_grad = d.sigmoid()
2101
+ lr_grad = lr_grad * (1 - lr_grad)
2102
+ lr_grad = lr_grad * s * g
2103
+ d = d - lr_grad * lr_lr
2104
+ copy_stochastic_(d_, d)
2105
+ copy_stochastic_(u_, u * d.sigmoid())
2106
+ copy_stochastic_(s_, u)
2107
+
2108
+
2109
+ def pointwise_lr_adaptation(
2110
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
2111
+ ):
2112
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2113
+ lr_lr = scalar_guard(lr_lr, grads[0])
2114
+ _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
2115
+
2116
+
1762
2117
  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
1763
2118
  optimizers = {}
1764
2119
 
@@ -1781,8 +2136,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1781
2136
 
1782
2137
  o = optimizer(parameters, *args, **kwargs)
1783
2138
  step_fn = o.step
1784
- o.step = functools.partial(warn_once,
1785
- msg="You're trying to call `step` on a fused optimizer. This will not do anything.")
2139
+ o.step = functools.partial(
2140
+ warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
2141
+ )
1786
2142
 
1787
2143
  def _step(p: Tensor):
1788
2144
  seen_params.add(p)