heavyball 1.6.3-py3-none-any.whl → 1.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/utils.py CHANGED
@@ -1,11 +1,13 @@
1
+ import contextlib
1
2
  import functools
2
3
  import gc
4
+ import inspect
3
5
  import math
4
6
  import random
7
+ import re
5
8
  import string
6
9
  import warnings
7
- from typing import List, Optional, Tuple, Callable, Union
8
- from unittest.mock import patch
10
+ from typing import Callable, List, Optional, Tuple, Union
9
11
 
10
12
  import numpy as np
11
13
  import torch
@@ -15,19 +17,22 @@ from torch._dynamo.exc import TorchDynamoException
15
17
  from torch.backends import cudnn, opt_einsum
16
18
  from torch.utils._pytree import tree_map
17
19
 
18
- config.cache_size_limit = 2 ** 16
19
-
20
- np.warnings = warnings
20
+ config.cache_size_limit = 2**16
21
21
 
22
22
  compile_mode = "max-autotune-no-cudagraphs"
23
23
  dynamic = False
24
24
  compile_mode_recommended_to_none = None
25
- zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
25
+ zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
26
26
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
-
28
- base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
29
- 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
30
- 'weight_decay': 1e-4}
27
+ _cudnn_double_backward_pattern = re.compile(
28
+ r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
29
+ )
30
+ _torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
31
+ _fd_error = (
32
+ "You can accelerate startup by globally enabling finite_differences first " #
33
+ "(via opt.finite_differences=True or by subclassing it)\n"
34
+ "Original Error: "
35
+ )
31
36
 
32
37
 
33
38
  def decorator(func):
@@ -35,7 +40,6 @@ def decorator(func):
35
40
 
36
41
  @functools.wraps(func)
37
42
  def _fn(*args, **kwargs):
38
- disable = compile_mode_recommended_to_none is None
39
43
  if is_compiling() or compile_mode_recommended_to_none is None:
40
44
  return func(*args, **kwargs)
41
45
  nonlocal compiled
@@ -65,8 +69,17 @@ einsum_base = string.ascii_lowercase
65
69
 
66
70
 
67
71
  @decorator_knowngood
68
- def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
69
- beta1: Tensor, decay: float, grad: List[Tensor], caution):
72
+ def _compilable_schedule_free_(
73
+ p: List[Tensor],
74
+ z: List[Tensor],
75
+ ckp1: Tensor,
76
+ update: List[Tensor],
77
+ lr: Tensor,
78
+ beta1: Tensor,
79
+ decay: float,
80
+ grad: List[Tensor],
81
+ caution,
82
+ ):
70
83
  for op, oz, u_, g_ in zip(p, z, update, grad):
71
84
  u_ = u_.view_as(op)
72
85
  p_, z_, u_ = map(promote, (op, oz, u_))
@@ -81,9 +94,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
81
94
  copy_stochastic_(oz, z_)
82
95
 
83
96
 
84
- def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
85
- z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
86
- step: int = 0, decay: float = 0.0):
97
+ def schedule_free_(
98
+ lr: float,
99
+ weight_lr_power: float,
100
+ weight_sum: float,
101
+ beta1: float,
102
+ parameters: List[Tensor],
103
+ z: List[Tensor],
104
+ update: List[Tensor],
105
+ grad: List[Tensor],
106
+ caution: bool = False,
107
+ r: float = 0.0,
108
+ step: int = 0,
109
+ decay: float = 0.0,
110
+ ):
87
111
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
88
112
  weight_sum = weight_sum + weight
89
113
 
@@ -156,7 +180,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
156
180
 
157
181
 
158
182
  def beta_debias(beta, step):
159
- return 1 - (1 - beta) / (1 - beta ** step)
183
+ return 1 - (1 - beta) / (1 - beta**step)
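Note: `beta_debias` folds Adam-style bias correction into the decay factor itself, returning the effective beta so the running average needs no separate `1 - beta**step` division. A worked example using the function above:

beta = 0.9
beta_debias(beta, 1)     # ~0.0   -> the EMA equals the first sample exactly (no startup bias)
beta_debias(beta, 10)    # ~0.8465
beta_debias(beta, 1000)  # ~0.9   -> the correction washes out and the nominal beta is recovered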
160
184
 
161
185
 
162
186
  def eps_sqrt(item, eps):
@@ -164,8 +188,9 @@ def eps_sqrt(item, eps):
164
188
 
165
189
 
166
190
  @decorator_knowngood
167
- def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
168
- out: List[Optional[Tensor]]):
191
+ def _compilable_exp_avg_sq_(
192
+ state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]]
193
+ ):
169
194
  g32 = promote(grad)
170
195
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
171
196
 
@@ -226,8 +251,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
226
251
  copy_stochastic_list_(gradients, g32)
227
252
 
228
253
 
229
- def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
230
- minimum: float = 1e-3, eps: float = 1e-8):
254
+ def adaptive_gradient_clipping_(
255
+ parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8
256
+ ):
231
257
  if clip_val <= 0:
232
258
  return gradients
233
259
  parameters, gradients = list_guard(parameters, gradients)
@@ -253,23 +279,24 @@ def clean():
253
279
 
254
280
 
255
281
  def _ignore_warning(msg):
256
- warnings.filterwarnings('ignore', f'.*{msg}.*')
282
+ warnings.filterwarnings("ignore", f".*{msg}.*")
257
283
 
258
284
 
259
- def set_torch(benchmark_limit: int = 32):
285
+ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
260
286
  cudnn.benchmark = True
261
287
  cudnn.deterministic = False
262
288
  cudnn.benchmark_limit = benchmark_limit
263
289
  torch.use_deterministic_algorithms(False)
264
290
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
265
- opt_einsum.enabled = False
266
- opt_einsum.strategy = "auto"
291
+ opt_einsum.set_flags(True, einsum_strategy)
267
292
 
268
293
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
269
294
  _ignore_warning(
270
- 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
295
+ "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak"
296
+ )
271
297
  _ignore_warning(
272
- 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')
298
+ "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak"
299
+ )
273
300
 
274
301
 
275
302
  @decorator
@@ -277,7 +304,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
277
304
  assert len(G.shape) == 2
278
305
  a, b, c = (3.4445, -4.7750, 2.0315)
279
306
  X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
280
- X /= (X.norm() + eps) # ensure top singular value <= 1
307
+ X /= X.norm() + eps # ensure top singular value <= 1
281
308
  if G.size(0) > G.size(1):
282
309
  X = X.T
283
310
  for _ in range(steps):
@@ -290,10 +317,10 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
290
317
 
291
318
 
292
319
  def ortho(x):
293
- if zeroth_power_mode == 'qr':
320
+ if zeroth_power_mode == "qr":
294
321
  return torch.linalg.qr(x).Q
295
- if zeroth_power_mode == 'svd':
296
- u, s, v = torch.linalg.svd(x)
322
+ if zeroth_power_mode == "svd":
323
+ u, _s, v = torch.linalg.svd(x)
297
324
  return u @ v.T
298
325
  raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
299
326
 
@@ -351,12 +378,12 @@ def _compilable_grafting(magnitude, direction):
351
378
 
352
379
  @decorator_knowngood
353
380
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
354
- if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
381
+ if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
355
382
  y = zeropower_via_newtonschulz5(x, 5)
356
- elif mode == 'qr':
383
+ elif mode == "qr":
357
384
  y = torch.linalg.qr(promote(x)).Q
358
- elif mode == 'svd':
359
- u, s, v = torch.linalg.svd(promote(x))
385
+ elif mode == "svd":
386
+ u, _s, v = torch.linalg.svd(promote(x))
360
387
  y = u @ v.T
361
388
  else:
362
389
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -403,7 +430,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
403
430
  q_old = promote(q.data)
404
431
 
405
432
  tmp = m @ q_old
406
- est_eig = torch.einsum('ij,ij->j', q_old, tmp)
433
+ est_eig = torch.einsum("ij,ij->j", q_old, tmp)
407
434
  sort_idx = torch.argsort(est_eig, descending=True)
408
435
 
409
436
  tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
@@ -415,19 +442,20 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
415
442
  return
416
443
 
417
444
  assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
418
- in_str = einsum_base[:exp_avg.dim()]
419
- out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
445
+ in_str = einsum_base[: exp_avg.dim()]
446
+ out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()]
420
447
 
421
448
  from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
422
449
  if not from_shampoo:
423
450
  return
424
451
 
425
- to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
426
- out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
452
+ to_shampoo = ",".join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
453
+ out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
427
454
 
428
- subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
429
- exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
430
- *[q for q in new_qs if q is not None])
455
+ subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
456
+ exp_avg_new = torch.einsum(
457
+ subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None]
458
+ )
431
459
  copy_stochastic_(exp_avg, exp_avg_new)
432
460
 
433
461
  for q, q_new in zip(Q, new_qs):
@@ -453,11 +481,11 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
453
481
  while True:
454
482
  try:
455
483
  eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
456
- eigval, eigvec = torch.linalg.eigh(m + eps * eye)
484
+ _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
457
485
  eigvec = eigvec.to(device=device, dtype=dtype)
458
486
  break
459
487
  except torch.OutOfMemoryError:
460
- if m.device.type == 'cpu':
488
+ if m.device.type == "cpu":
461
489
  raise
462
490
  else:
463
491
  m = m.cpu()
@@ -489,21 +517,21 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
489
517
 
490
518
  def get_beta1(group):
491
519
  beta = None
492
- if 'beta' in group:
493
- beta = group['beta']
494
- if beta is None and 'betas' in group:
495
- beta = group['betas'][0]
520
+ if "beta" in group:
521
+ beta = group["beta"]
522
+ if beta is None and "betas" in group:
523
+ beta = group["betas"][0]
496
524
  if beta is None:
497
525
  raise ValueError("Beta not found in group.")
498
526
  return beta
499
527
 
500
528
 
501
529
  def get_beta2(group):
502
- if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
530
+ if "palm" in group and group["palm"] is True and "beta2_scale" in group:
503
531
  step = max(group.get("step", 1), 1)
504
- return 1 - step ** -group['beta2_scale']
505
- if 'betas' in group:
506
- return group['betas'][1]
532
+ return 1 - step ** -group["beta2_scale"]
533
+ if "betas" in group:
534
+ return group["betas"][1]
507
535
  raise ValueError("Beta2 not found in group.")
508
536
 
509
537
 
@@ -554,6 +582,20 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
554
582
  _compilable_stochastic_add_(x, y, alpha)
555
583
 
556
584
 
585
+ @decorator_knowngood
586
+ def _compilable_stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Tensor, divisor: Tensor):
587
+ for x_, y_ in zip(x, y):
588
+ x32 = promote(x_)
589
+ y32 = promote(y_)
590
+ copy_stochastic_(x_, (x32 + y32 * alpha) / divisor)
591
+
592
+
593
+ def stochastic_add_divide_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1, divisor: float = 1):
594
+ x, y = list_guard(x, y)
595
+ alpha, divisor = scalar_guard(alpha, divisor, x[0])
596
+ _compilable_stochastic_add_divide_(x, y, alpha, divisor)
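Note: `stochastic_add_divide_` is a new fused helper: each element of `x` is overwritten with `(x + alpha * y) / divisor`, computed after fp32 promotion and written back with stochastic rounding. A minimal eager-mode reference of the arithmetic (illustrative sketch; the stochastic-rounding writeback is omitted):

def add_divide_reference(x, y, alpha=1.0, divisor=1.0):
    # x_i <- (x_i + alpha * y_i) / divisor, all math in fp32
    return [(x_.float() + y_.float() * alpha) / divisor for x_, y_ in zip(x, y)]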
597
+
598
+
557
599
  @decorator_knowngood
558
600
  def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
559
601
  for x_, y_ in zip(x, y):
@@ -580,9 +622,9 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
580
622
  if not isinstance(m, Tensor):
581
623
  continue
582
624
  b = einsum_base[idx]
583
- g0 = einsum_base[:grad.dim()]
625
+ g0 = einsum_base[: grad.dim()]
584
626
  g1 = g0.replace(b, b.upper())
585
- outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
627
+ outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
586
628
  stochastic_lerp_(m, outer_product, 1 - beta)
587
629
 
588
630
 
@@ -602,6 +644,20 @@ def promote(x):
602
644
  return x
603
645
 
604
646
 
647
+ def promote_detach(x, should_promote):
648
+ if x is None:
649
+ return x
650
+ if should_promote:
651
+ x = promote(x)
652
+ return x.detach()
653
+
654
+
655
+ def detach(x):
656
+ if isinstance(x, Tensor):
657
+ return x.detach()
658
+ return x
659
+
660
+
605
661
  def min_dtype(xs: List[Tensor]):
606
662
  dtypes = [x.dtype for x in xs]
607
663
  for d in (torch.float32, torch.bfloat16, torch.float16):
@@ -623,19 +679,19 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
623
679
  """
624
680
  Initializes the preconditioner matrices (L and R in the paper).
625
681
  """
626
- state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
682
+ state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper).
627
683
  if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
628
684
  for sh in grad.shape:
629
685
  if sh > max_precond_dim or sh == 1:
630
686
  # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
631
- state['GG'].append(None)
687
+ state["GG"].append(None)
632
688
  else:
633
- state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
689
+ state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
634
690
  else:
635
- state['GG'].append(None)
691
+ state["GG"].append(None)
636
692
 
637
- update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
638
- state['Q'] = get_orthogonal_matrix(state['GG'])
693
+ update_ggt(grad, state["GG"], max_precond_dim, precondition_1d, 0)
694
+ state["Q"] = get_orthogonal_matrix(state["GG"])
639
695
 
640
696
 
641
697
  @decorator
@@ -646,34 +702,45 @@ def project(grad, Q, back: bool):
646
702
  :param back: whether to project to Shampoo eigenbases or back to original space
647
703
  :return:
648
704
  """
649
- param = einsum_base[:grad.dim()]
650
- preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
705
+ param = einsum_base[: grad.dim()]
706
+ preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
651
707
  if preconditioners:
652
- out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
653
- out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
708
+ out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
709
+ out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
654
710
  grad = out.to(grad.dtype)
655
711
  return grad
656
712
 
657
713
 
658
- def modify_closure(closure):
659
- """
660
- Modifies the closure function to use create_graph=True in backward().
714
+ @contextlib.contextmanager
715
+ def patch_backward():
716
+ @contextlib.contextmanager
717
+ def _inner(module):
718
+ original = module.backward
661
719
 
662
- Args:
663
- closure: The closure function passed to the optimizer.
720
+ signature = inspect.signature(original)
721
+
722
+ def patched_backward(*args, **kwargs):
723
+ new_kwargs = signature.bind(*args)
724
+ new_kwargs.apply_defaults()
725
+ new_kwargs = new_kwargs.arguments
726
+ new_kwargs.update(kwargs)
727
+ new_kwargs["create_graph"] = True
728
+ return original(**new_kwargs)
729
+
730
+ module.backward = patched_backward
731
+ yield
732
+ module.backward = original
733
+
734
+ with _inner(torch.Tensor), _inner(torch.autograd):
735
+ yield
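Note: `patch_backward` replaces the old `modify_closure`/`unittest.mock.patch` approach. Inside the context manager, both `torch.Tensor.backward` and `torch.autograd.backward` are temporarily swapped for wrappers that force `create_graph=True`, and the `inspect.signature` binding keeps a positionally passed argument from colliding with the injected keyword. A reduced sketch of the same idea, assuming only `Tensor.backward` needs patching:

import contextlib
import torch

@contextlib.contextmanager
def force_create_graph():
    original = torch.Tensor.backward

    def patched(self, *args, **kwargs):
        kwargs["create_graph"] = True  # keep the graph so the gradients stay differentiable
        return original(self, *args, **kwargs)

    torch.Tensor.backward = patched
    try:
        yield
    finally:
        torch.Tensor.backward = original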
664
736
 
665
- Returns:
666
- The return value of the modified closure.
667
- """
668
737
 
669
- def patched_backward(self, *args, **kwargs):
670
- kwargs['create_graph'] = True
671
- return original_backward(self, *args, **kwargs)
738
+ def hasattr_none(obj, name):
739
+ return getattr(obj, name, None) is not None
672
740
 
673
- original_backward = torch.Tensor.backward
674
741
 
675
- with patch.object(torch.Tensor, 'backward', patched_backward):
676
- return closure()
742
+ class ExactHVPFailed(ValueError):
743
+ pass
677
744
 
678
745
 
679
746
  class StatefulOptimizer(torch.optim.Optimizer):
@@ -683,18 +750,22 @@ class StatefulOptimizer(torch.optim.Optimizer):
683
750
  The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
684
751
  Further notice that both methods have different numerics outputs
685
752
  """
753
+
686
754
  ema_decay: float = 0.001
687
755
  compile_step: bool = False
688
756
  hessian_approx: bool = False
689
757
  precond_schedule: Union[Callable, float, None] = None
690
758
  stochastic_schedule: bool = False
691
759
  finite_differences: bool = False
760
+ fallback_to_finite_differences: bool = True
761
+ _fallback_enabled: bool = False
762
+ hvp_interval: int = 1 # grad is faster initially, hvp later
692
763
 
693
764
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
694
- super().__init__(params, {**defaults, 'foreach': foreach})
765
+ super().__init__(params, {**defaults, "foreach": foreach})
695
766
  self.use_ema = use_ema
696
767
  self.mapping = {}
697
- self._inner_group = {'stochastic_schedule': self.stochastic_schedule}
768
+ self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
698
769
  self._precond_rng = random.Random(0x12312)
699
770
  self._is_preconditioning = None
700
771
 
@@ -710,34 +781,51 @@ class StatefulOptimizer(torch.optim.Optimizer):
710
781
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
711
782
  for p, g in zip(p_list, g_list):
712
783
  state = self.state_(p)
713
- if 'mars_old_grad' not in state:
714
- state['mars_old_grad'] = torch.zeros_like(g)
715
- old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
784
+ if "mars_old_grad" not in state:
785
+ state["mars_old_grad"] = torch.zeros_like(g)
786
+ old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
716
787
  mars_correction(g_list, old_gs, mars_gamma, beta)
717
788
 
718
- def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
719
- beta1: float = -1.0):
789
+ def split_p_and_g_in_group(
790
+ self,
791
+ group: dict,
792
+ skip_none: bool = True,
793
+ should_promote: bool = True,
794
+ beta1: float = -1.0,
795
+ raw: bool = False,
796
+ ):
720
797
  for p in group["params"]:
798
+ grad = getattr(p, "grad", None)
799
+ if grad is None and skip_none:
800
+ continue
801
+
802
+ p.grad = None
803
+
804
+ if raw:
805
+ yield p, grad
806
+ continue
807
+
721
808
  if p in self.mapping:
722
809
  p_views = self.mapping[p]
723
810
  else:
724
811
  self.mapping[p] = p_views = merge_group(group, p)
725
812
 
726
- grad = getattr(p, 'grad', None)
727
- p.grad = None
728
-
729
- if grad is None:
730
- grad = [getattr(pv, 'grad', None) for pv in p_views]
731
- else:
732
- grad = merge_group(group, grad)
733
-
734
- for pv, g in zip(p_views, grad):
735
- if skip_none and g is None:
736
- continue
737
- if should_promote:
738
- g = promote(g)
739
- if beta1 >= 0 and group.get('mars', False):
740
- self.mars_correct_list(group, [pv], [g], group['mars_gamma'], beta1)
813
+ vector = getattr(p, "vector", None)
814
+ hessian_vector = getattr(p, "hessian_vector", None)
815
+ p.vector = None
816
+ p.hessian_vector = None
817
+
818
+ grad, vs, hvs = [
819
+ [None] * len(p_views) if x is None else merge_group(group, x) #
820
+ for x in (grad, vector, hessian_vector)
821
+ ]
822
+
823
+ for pv, g, v, hv in zip(p_views, grad, vs, hvs):
824
+ g = promote_detach(g, should_promote)
825
+ if beta1 >= 0 and group.get("mars", False):
826
+ self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
827
+ pv.vector = promote_detach(v, should_promote)
828
+ pv.hessian_vector = promote_detach(hv, should_promote)
741
829
  yield pv, g
742
830
 
743
831
  def state_size(self) -> int:
@@ -759,48 +847,108 @@ class StatefulOptimizer(torch.optim.Optimizer):
759
847
  def ema_update(self):
760
848
  with torch.no_grad():
761
849
  for group in self.param_groups:
762
- active_p = [p for p in group['params']]
850
+ active_p = [p for p in group["params"]]
763
851
 
764
852
  if not active_p:
765
853
  return
766
854
 
767
- k = group['ema_step'] = group.get('ema_step', -1) + 1
855
+ k = group["ema_step"] = group.get("ema_step", -1) + 1
768
856
 
769
857
  for p in active_p:
770
- if 'param_ema' not in self.state_(p):
771
- self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
858
+ if "param_ema" not in self.state_(p):
859
+ self.state_(p)["param_ema"] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
772
860
 
773
- y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
861
+ y, param_ema = zip(*[(p.data, self.state_(p)["param_ema"]) for p in active_p])
774
862
  torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
775
863
 
776
864
  def copy_emas_to_params(self):
777
865
  with torch.no_grad():
778
866
  for group in self.param_groups:
779
- active_p = [p for p in group['params']]
867
+ active_p = [p for p in group["params"]]
780
868
 
781
869
  if not active_p:
782
870
  return
783
871
 
784
872
  for p in active_p:
785
- if 'param_ema' in self.state_(p):
873
+ if "param_ema" in self.state_(p):
786
874
  p_clone = p.data.clone()
787
- set_(p.data, self.state_(p)['param_ema'])
788
- set_(self.state_(p)['param_ema'], p_clone)
875
+ set_(p.data, self.state_(p)["param_ema"])
876
+ set_(self.state_(p)["param_ema"], p_clone)
789
877
 
790
878
  def copy_params_to_emas(self):
791
879
  with torch.no_grad():
792
880
  for group in self.param_groups:
793
- active_p = [p for p in group['params']]
881
+ active_p = [p for p in group["params"]]
794
882
 
795
883
  if not active_p:
796
884
  return
797
885
 
798
886
  for p in active_p:
799
- if 'param_ema' in self.state_(p):
800
- ema_clone = self.state_(p)['param_ema'].data.clone()
801
- set_(self.state_(p)['param_ema'], p.data)
887
+ if "param_ema" in self.state_(p):
888
+ ema_clone = self.state_(p)["param_ema"].data.clone()
889
+ set_(self.state_(p)["param_ema"], p.data)
802
890
  set_(p.data, ema_clone)
803
891
 
892
+ def _finite_differences_hvp(self, closure):
893
+ with torch.enable_grad():
894
+ loss = closure() # closure without retain_graph=True
895
+
896
+ grads = []
897
+ for group in self.param_groups:
898
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
899
+ grads.append(g)
900
+ p.vector = torch.randn_like(p)
901
+ p.orig = p.data.clone()
902
+ # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
903
+ stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
904
+
905
+ with torch.enable_grad():
906
+ closure()
907
+
908
+ # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
909
+ # this costs more memory, but the imprecision seems too severe to use the other method
910
+ for group in self.param_groups:
911
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
912
+ p.grad = grads.pop(0)
913
+ stochastic_add_(g, p.grad, -1) # technically, we have to divide by the scale here
914
+ p.hessian_vector = g
915
+ p.data.copy_(p.orig)
916
+ del p.orig
917
+ return loss
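Note: `_finite_differences_hvp` estimates a Hessian-vector product from two gradient evaluations, Hv ≈ (∇f(θ + εv) − ∇f(θ)) / ε with ε equal to the square root of the dtype's machine epsilon. As the comments note, it keeps the unscaled difference and restores θ from a saved copy rather than subtracting εv again, trading memory for precision. A textbook version of the estimator for reference (illustrative sketch; names are not from the package):

import torch

def fd_hvp(loss_fn, params, eps=1e-3):
    # Hv ~= (grad(theta + eps * v) - grad(theta)) / eps, for a random direction v
    grads = torch.autograd.grad(loss_fn(), params)
    vs = [torch.randn_like(p) for p in params]
    with torch.no_grad():
        for p, v in zip(params, vs):
            p.add_(v, alpha=eps)
    shifted = torch.autograd.grad(loss_fn(), params)
    with torch.no_grad():
        for p, v in zip(params, vs):
            p.sub_(v, alpha=eps)  # restore the original parameters
    return vs, [(gs - g) / eps for g, gs in zip(grads, shifted)]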
918
+
919
+ def _double_backward_hvp(self, closure):
920
+ with torch.enable_grad(), patch_backward():
921
+ loss = closure()
922
+
923
+ params, grads = [], []
924
+ for group in self.param_groups:
925
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
926
+ params.append(p)
927
+ grads.append(g)
928
+
929
+ if not params:
930
+ raise ValueError("No parameter has gradients")
931
+
932
+ vs = [torch.randn_like(p) for p in params]
933
+ with torch.enable_grad():
934
+ try:
935
+ hvs = torch.autograd.grad(grads, params, vs, create_graph=False, retain_graph=False, allow_unused=True)
936
+ except RuntimeError as e:
937
+ raise ExactHVPFailed(str(e.args))
938
+
939
+ unused = []
940
+ for p, g, v, hv in zip(params, grads, vs, hvs):
941
+ p.hessian_vector = detach(hv)
942
+ p.grad = detach(g)
943
+ p.vector = detach(v)
944
+ if hv is None:
945
+ unused.append(list(p.shape))
946
+
947
+ if unused:
948
+ raise ExactHVPFailed(f"Parameters with the following shapes have no 2nd order derivative: {unused}")
949
+
950
+ return loss
951
+
804
952
  def _handle_closure(self, closure):
805
953
  hessian_approx = self.hessian_approx and self._is_preconditioning
806
954
 
@@ -809,53 +957,41 @@ class StatefulOptimizer(torch.optim.Optimizer):
809
957
  raise ValueError("Hessian approximation requires a closure.")
810
958
  return None
811
959
 
812
- if not hessian_approx:
960
+ step = self._inner_group["total_hvp_steps"] = self._inner_group.get("total_hvp_steps", 0) + 1
961
+ if not hessian_approx or step % self.hvp_interval == 0:
813
962
  with torch.enable_grad():
814
963
  loss = closure()
815
964
  return loss
816
965
 
817
- if self.finite_differences:
818
- with torch.enable_grad():
819
- loss = closure() # closure without retain_graph=True
820
-
821
- grads = []
822
- for group in self.param_groups:
823
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
824
- grads.append(g)
825
- p.vector = torch.randn_like(p)
826
- p.orig = p.data.clone()
827
- stochastic_add_(p.data, p.vector, tiny_bf16)
828
- else:
829
- with torch.enable_grad():
830
- loss = modify_closure(closure)
831
-
832
- if self.finite_differences:
833
- with torch.enable_grad():
834
- closure()
835
-
836
- for group in self.param_groups:
837
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
838
- p.grad = grads.pop(0)
839
- stochastic_add_(g, p.grad, -1)
840
- p.hessian_vector = g
841
- p.data.copy_(p.orig)
842
- del p.orig
843
- else:
844
- for group in self.param_groups:
845
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
846
- p.grad = g
847
- params, grads = zip(*[x for group in self.param_groups for x in
848
- self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
849
- vs = [torch.randn_like(p) for p in params]
850
- with torch.enable_grad():
851
- hvs = torch.autograd.grad(grads, params, vs)
852
-
853
- for p, g, v, hv in zip(params, grads, vs, hvs):
854
- p.hessian_vector = hv
855
- p.grad = g
856
- p.vector = v
857
-
858
- return loss
966
+ if self.finite_differences or self._fallback_enabled:
967
+ return self._finite_differences_hvp(closure)
968
+
969
+ try:
970
+ return self._double_backward_hvp(closure)
971
+ except NotImplementedError as e:
972
+ if not self.fallback_to_finite_differences:
973
+ raise
974
+ if not any(isinstance(arg, str) and _cudnn_double_backward_pattern.match(arg) for arg in e.args):
975
+ raise
976
+ warn_once(
977
+ "CUDNN doesn't support double-backward for some models (including RNNs). " #
978
+ f"Falling back to finite_differences.\n{_fd_error}{e}"
979
+ )
980
+ except RuntimeError as e:
981
+ if not self.fallback_to_finite_differences:
982
+ raise
983
+ if not any(isinstance(arg, str) and _torch_compile_double_backward_pattern.match(arg) for arg in e.args):
984
+ raise
985
+ warn_once(
986
+ f"torch.compile does not support double-backward. Disabling it may be beneficial, depending on "
987
+ f"the model.\n{_fd_error}{e}"
988
+ )
989
+ except ExactHVPFailed as e:
990
+ if not self.fallback_to_finite_differences:
991
+ raise
992
+ warn_once(f"Exact HVP calculation failed.\n{_fd_error}{e}")
993
+ self._fallback_enabled = True
994
+ return self._handle_closure(closure)
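Note: the dispatch above tries the exact double-backward HVP first and permanently switches to finite differences once a known-unsupported case is hit (cuDNN double backward, torch.compile double backward, or an `ExactHVPFailed`). Reduced to a sketch (hypothetical method name; the error-pattern checks are omitted):

def _handle_hvp(self, closure):
    if self.finite_differences or self._fallback_enabled:
        return self._finite_differences_hvp(closure)
    try:
        return self._double_backward_hvp(closure)
    except (NotImplementedError, RuntimeError, ExactHVPFailed) as e:
        if not self.fallback_to_finite_differences:
            raise
        warn_once(f"Exact HVP failed, falling back to finite differences: {e}")
        self._fallback_enabled = True       # remember the failure across steps
        return _handle_hvp(self, closure)   # retry this step with the fallback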
859
995
 
860
996
  def step(self, closure: Optional[Callable] = None):
861
997
  if self.precond_schedule is None:
@@ -867,11 +1003,15 @@ class StatefulOptimizer(torch.optim.Optimizer):
867
1003
  # we assume that parameters are constant and that there are no excessive recompiles
868
1004
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
869
1005
  for group in self.param_groups:
870
- group['is_preconditioning'] = self._is_preconditioning
1006
+ group["is_preconditioning"] = self._is_preconditioning
871
1007
  self._step(group)
872
1008
  if self.use_ema:
873
1009
  self.ema_update()
874
-
1010
+ for real, views in self.mapping.items():
1011
+ for tensor in (real, *views):
1012
+ for key in ("grad", "vector", "hessian_vector", "orig"):
1013
+ if hasattr(tensor, key):
1014
+ setattr(tensor, key, None)
875
1015
  return loss
876
1016
 
877
1017
 
@@ -891,8 +1031,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
891
1031
 
892
1032
 
893
1033
  @decorator_knowngood
894
- def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
895
- step: Tensor, eps: Tensor):
1034
+ def _compilable_adam_(
1035
+ exp_avg: List[Tensor],
1036
+ exp_avg_sq: List[Tensor],
1037
+ grad: List[Tensor],
1038
+ beta1: Tensor,
1039
+ beta2: Tensor,
1040
+ step: Tensor,
1041
+ eps: Tensor,
1042
+ ):
896
1043
  beta1 = beta_debias(beta1, step)
897
1044
  beta2 = beta_debias(beta2, step)
898
1045
 
@@ -903,8 +1050,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
903
1050
  copy_stochastic_list_(grad, u32)
904
1051
 
905
1052
 
906
- def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
907
- eps: float = 1e-8):
1053
+ def adam_(
1054
+ exp_avg: List[Tensor],
1055
+ exp_avg_sq: List[Tensor],
1056
+ grad: List[Tensor],
1057
+ beta1: float,
1058
+ beta2: float,
1059
+ step: int,
1060
+ eps: float = 1e-8,
1061
+ ):
908
1062
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
909
1063
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
910
1064
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -912,9 +1066,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
912
1066
 
913
1067
 
914
1068
  @decorator_knowngood
915
- def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
916
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
917
- eps: Tensor, caution: bool):
1069
+ def _fused_compilable_adam_(
1070
+ y: List[Tensor],
1071
+ exp_avg: List[Tensor],
1072
+ exp_avg_sq: List[Tensor],
1073
+ update: List[Tensor],
1074
+ grad: List[Tensor],
1075
+ beta1: Tensor,
1076
+ beta2: Tensor,
1077
+ step: Tensor,
1078
+ decay: Tensor,
1079
+ lr: Tensor,
1080
+ eps: Tensor,
1081
+ caution: bool,
1082
+ ):
918
1083
  beta1 = beta_debias(beta1, step)
919
1084
  beta2 = beta_debias(beta2, step)
920
1085
 
@@ -925,17 +1090,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
925
1090
  _compilable_update_(y, u32, decay, lr, caution, g32)
926
1091
 
927
1092
 
928
- def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
929
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
930
- caution: bool):
1093
+ def fused_adam_(
1094
+ y: List[Tensor],
1095
+ exp_avg: List[Tensor],
1096
+ exp_avg_sq: List[Tensor],
1097
+ update: List[Tensor],
1098
+ grad: List[Tensor],
1099
+ beta1: float,
1100
+ beta2: float,
1101
+ step: int,
1102
+ lr: float,
1103
+ eps: float,
1104
+ decay: float,
1105
+ caution: bool,
1106
+ ):
931
1107
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
932
1108
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
933
1109
  _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
934
1110
 
935
1111
 
936
1112
  @decorator_knowngood
937
- def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
938
- beta2: Tensor, step: Tensor, eps: Tensor):
1113
+ def _compilable_laprop_(
1114
+ exp_avg: List[Tensor],
1115
+ exp_avg_sq: List[Tensor],
1116
+ grad: List[Tensor],
1117
+ beta1: Tensor,
1118
+ beta2: Tensor,
1119
+ step: Tensor,
1120
+ eps: Tensor,
1121
+ ):
939
1122
  beta1 = beta_debias(beta1, step)
940
1123
  beta2 = beta_debias(beta2, step)
941
1124
 
@@ -946,8 +1129,15 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
946
1129
  copy_stochastic_list_(grad, gp32)
947
1130
 
948
1131
 
949
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
950
- eps: float = 1e-8):
1132
+ def laprop_(
1133
+ exp_avg: List[Tensor],
1134
+ exp_avg_sq: List[Tensor],
1135
+ grad: List[Tensor],
1136
+ beta1: float,
1137
+ beta2: float,
1138
+ step: int,
1139
+ eps: float = 1e-8,
1140
+ ):
951
1141
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
952
1142
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
953
1143
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -955,9 +1145,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
955
1145
 
956
1146
 
957
1147
  @decorator_knowngood
958
- def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
959
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
960
- caution: bool, eps: Tensor):
1148
+ def _fused_compilable_laprop_(
1149
+ y: List[Tensor],
1150
+ exp_avg: List[Tensor],
1151
+ exp_avg_sq: List[Tensor],
1152
+ update: List[Tensor],
1153
+ grad: List[Tensor],
1154
+ beta1: Tensor,
1155
+ beta2: Tensor,
1156
+ step: Tensor,
1157
+ lr: Tensor,
1158
+ decay: Tensor,
1159
+ caution: bool,
1160
+ eps: Tensor,
1161
+ ):
961
1162
  beta1 = beta_debias(beta1, step)
962
1163
  beta2 = beta_debias(beta2, step)
963
1164
 
@@ -968,9 +1169,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
968
1169
  _compilable_update_(y, u32, decay, lr, caution, gp32)
969
1170
 
970
1171
 
971
- def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
972
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
973
- eps: float = 1e-8):
1172
+ def fused_laprop_(
1173
+ y: List[Tensor],
1174
+ exp_avg: List[Tensor],
1175
+ exp_avg_sq: List[Tensor],
1176
+ update: List[Tensor],
1177
+ grad: List[Tensor],
1178
+ beta1: float,
1179
+ beta2: float,
1180
+ step: int,
1181
+ lr: float,
1182
+ decay: float,
1183
+ caution: bool,
1184
+ eps: float = 1e-8,
1185
+ ):
974
1186
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
975
1187
  beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
976
1188
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -978,7 +1190,7 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
978
1190
 
979
1191
  @decorator_knowngood
980
1192
  def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
981
- u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
1193
+ u32, g32, exp_avg_sq32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq]]
982
1194
  _compilable_update_(y, u32, decay, lr, caution, g32)
983
1195
 
984
1196
  beta1 = beta_debias(beta1, step)
@@ -997,7 +1209,7 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e
997
1209
 
998
1210
  @decorator_knowngood
999
1211
  def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
1000
- g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
1212
+ g32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
1001
1213
  update = [e.clone() for e in exp_avg]
1002
1214
 
1003
1215
  beta1 = beta_debias(beta1, step)
@@ -1044,8 +1256,9 @@ def copy_stochastic_(target: Tensor, source: Tensor):
1044
1256
 
1045
1257
 
1046
1258
  @decorator_knowngood
1047
- def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
1048
- g: List[Optional[Tensor]]):
1259
+ def _compilable_update_(
1260
+ p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]]
1261
+ ):
1049
1262
  for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
1050
1263
  u_ = promote(u_.view_as(p_))
1051
1264
  p32_ = promote(p_)
@@ -1055,8 +1268,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
1055
1268
  copy_stochastic_(p_, p32_)
1056
1269
 
1057
1270
 
1058
- def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
1059
- grad: List[Tensor] = None):
1271
+ def update_param_(
1272
+ param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None
1273
+ ):
1060
1274
  param, update, grad = list_guard(param, update, grad)
1061
1275
  lr = scalar_guard(lr, param[0])
1062
1276
  if not caution:
@@ -1064,38 +1278,117 @@ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: f
1064
1278
  _compilable_update_(param, update, decay, lr, caution, grad)
1065
1279
 
1066
1280
 
1067
- def precond_schedule(step, precond_scheduler, rng):
1281
+ def precond_schedule(step, precond_scheduler):
1068
1282
  precond_prob = max(step, 1) ** precond_scheduler[0]
1069
1283
  precond_prob = math.log10(precond_prob)
1070
1284
  precond_prob = precond_prob ** precond_scheduler[1] + 1
1071
- precond_prob = 1 / precond_prob
1072
- update_precond = rng.random() < precond_prob
1073
- return update_precond
1285
+ return 1 / precond_prob
1074
1286
 
1075
1287
 
1076
1288
  def get_soap_precond_schedule(precond_scheduler):
1077
- rng = random.Random(0x12312)
1078
-
1079
- def _inner(step):
1080
- return precond_schedule(step, precond_scheduler, rng)
1081
-
1082
- return _inner
1289
+ return functools.partial(precond_schedule, precond_scheduler=precond_scheduler)
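Note: `precond_schedule` no longer samples; it now returns the refresh probability for a given step, and `get_soap_precond_schedule` simply binds the scheduler constants with `functools.partial`. Sampling (or deterministic accumulation) is left to the caller (compare `psgd_should_update` further down, which accepts either a float or a callable). Hypothetical usage, with scheduler constants made up for illustration:

schedule = get_soap_precond_schedule((0.25, 2.0))
schedule(100)   # ~0.8: probability of refreshing the preconditioner at step 100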
1083
1290
 
1084
1291
 
1085
1292
  def _max_idx(x: List[int]):
1086
1293
  return len(x) - 1 - np.argmax(x[::-1]) # we want to start counting from the back, as torch is fan-out/fan-in
1087
1294
 
1088
1295
 
1089
- def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
1090
- """For a scalar or tensor t, we initialize its preconditioner Q and
1296
+ @decorator_knowngood
1297
+ def stable_exp(x: Tensor):
1298
+ # fp16:
1299
+ # exp(x) is stable in [-17, 11]
1300
+ # `stable_exp` extends to [-17, 17]
1301
+ # average error (in [-10, 10]) increased from 2.288e-3 to 2.299e-3
1302
+ # fp32:
1303
+ # exp(x) is stable in [-103, 88]
1304
+ # `stable_exp` extends to [-103, 103]
1305
+ # average error (in [-87, 87]) reduced from 3.309-06 to 3.224-06
1306
+ return torch.where(x > 0, 1 / (-x).exp(), x.exp())
1307
+
1308
+
1309
+ @decorator_knowngood
1310
+ def mean_root(x: torch.Tensor, pow: float, eps=1e-12):
1311
+ # 1 / (mean(x ** pow) ** (1 / pow / 2))
1312
+ log_x = x.double().abs().clamp(min=eps).log()
1313
+ log_mean_x_pow = (log_x * pow).logsumexp(dim=0) - math.log(x.numel())
1314
+ return stable_exp(-log_mean_x_pow / pow / 2)
1315
+
1316
+
1317
+ @decorator_knowngood
1318
+ def divided_root(x: torch.Tensor, y: torch.Tensor, pow0: float, pow1: float, eps=1e-12):
1319
+ # mean(x ** pow0) ** (1 / pow0 / 2) / mean(y ** pow1) ** (1 / pow1 / 2)
1320
+ log_x = x.double().abs().clamp(min=eps).log()
1321
+ log_y = y.double().abs().clamp(min=eps).log()
1322
+
1323
+ x_normed = (log_x * pow0).logsumexp(dim=0) - math.log(x.numel())
1324
+ x_normed = x_normed / pow0 / 2
1325
+
1326
+ y_normed = (log_y * pow1).logsumexp(dim=0) - math.log(y.numel())
1327
+ y_normed = y_normed / pow1 / 2
1328
+
1329
+ return stable_exp(x_normed - y_normed)
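Note: `stable_exp` evaluates exp(x) as 1 / exp(-x) on the positive side, extending the usable fp32 range from roughly [-103, 88] to [-103, 103] (per the comment above), and `mean_root`/`divided_root` use it to compute root-mean-power statistics in log-space via `logsumexp` so large tensors cannot overflow; these feed the `precond_init_scale` heuristic defined next. A quick numerical check of the identity `mean_root` relies on (illustrative, not part of the package):

import math
import torch

def stable_exp(z):  # same identity as above: exp(z) == 1 / exp(-z)
    return torch.where(z > 0, 1 / (-z).exp(), z.exp())

x, p = torch.randn(1000, dtype=torch.float64), 4.0
naive = 1 / x.abs().pow(p).mean().pow(1 / (2 * p))           # 1 / mean(|x|^p)^(1/(2p))
log_mean = (x.abs().clamp(min=1e-12).log() * p).logsumexp(0) - math.log(x.numel())
assert torch.allclose(naive, stable_exp(-log_mean / (2 * p)))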
1330
+
1331
+
1332
+ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector, scale_max: float = 1e6):
1333
+ automatic_scale = True
1334
+ manual_hint = " Set it manually using `precond_init_scale=0.1`"
1335
+ if scale is not None:
1336
+ automatic_scale = False
1337
+ warn_once(
1338
+ "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics."
1339
+ )
1340
+ if scale_scale is not None and scale_scale != 1:
1341
+ warn_once(
1342
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale."
1343
+ )
1344
+ elif hessian_vector is None:
1345
+ scale = mean_root(grad, 4) * scale_scale
1346
+ else:
1347
+ scale = divided_root(vector, hessian_vector, 2, 4) * scale_scale
1348
+ if isinstance(scale, torch.Tensor):
1349
+ scale = scale.item() # slow, but necessary
1350
+ if np.isfinite(scale):
1351
+ if scale > scale_max or scale < 1 / scale_max:
1352
+ warn_once(f"The computed precond_init_scale {scale} is outside of the expected range.{manual_hint}")
1353
+ return scale
1354
+ if not automatic_scale:
1355
+ raise ValueError("The manually set precond_init_scale is not finite")
1356
+
1357
+ for x in (grad, hessian_vector, vector):
1358
+ if x is None:
1359
+ continue
1360
+ if torch.allclose(x, torch.zeros_like(x)).item():
1361
+ raise ValueError(f"Grad or HVP is all 0s, causing NaNs in precond_init_scale computation.{manual_hint}")
1362
+ if not torch.isfinite(x).all().item():
1363
+ raise ValueError("Grad or HVP is not finite")
1364
+ raise ValueError(f"Computed precond_init_scale is not finite.{manual_hint}")
1365
+
1366
+
1367
+ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
1368
+ scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1369
+ U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1370
+ V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1371
+ d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
1372
+ return U, V, d
1373
+
1374
+
1375
+ def init_Q_exprs(
1376
+ grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector, dtype=None
1377
+ ):
1378
+ """
1379
+ For a scalar or tensor `grad`, we initialize its preconditioner Q and
1091
1380
  reusable einsum expressions for updating Q and preconditioning gradient.
1381
+
1382
+ precond init scale computation from
1383
+ https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
1092
1384
  """
1385
+ scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1093
1386
  letters = string.ascii_lowercase + string.ascii_uppercase
1094
- dtype = dtype if dtype is not None else t.dtype
1095
- shape = t.shape
1387
+ dtype = dtype if dtype is not None else grad.dtype
1388
+ shape = grad.shape
1096
1389
 
1097
1390
  if len(shape) == 0: # scalar
1098
- Q = [scale * torch.ones_like(t, dtype=dtype)]
1391
+ Q = [scale * torch.ones_like(grad, dtype=dtype)]
1099
1392
  exprA = ",->"
1100
1393
  exprGs = [",->"]
1101
1394
  exprP = ",,->"
@@ -1103,7 +1396,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1103
1396
 
1104
1397
  # Tensor
1105
1398
  if len(shape) > 13:
1106
- raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
1399
+ raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
1107
1400
 
1108
1401
  scale = scale ** (1 / len(shape))
1109
1402
 
@@ -1119,8 +1412,10 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1119
1412
  elif memory_save_mode == "all_diag":
1120
1413
  dim_diag = [True for _ in shape]
1121
1414
  else:
1122
- raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1123
- "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
1415
+ raise ValueError(
1416
+ f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
1417
+ "[None, 'one_diag', 'all_diag', 'smart_one_diag']"
1418
+ )
1124
1419
 
1125
1420
  Q = []
1126
1421
  piece1A, piece2A, piece3A = ([], "", "")
@@ -1129,7 +1424,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1129
1424
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
1130
1425
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
1131
1426
  # use diagonal matrix as preconditioner for this dim
1132
- Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))
1427
+ Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
1133
1428
 
1134
1429
  piece1A.append(letters[i])
1135
1430
  piece2A = piece2A + letters[i]
@@ -1143,13 +1438,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1143
1438
  piece4P = piece4P + letters[i + 13]
1144
1439
  else:
1145
1440
  # use triangular matrix as preconditioner for this dim
1146
- Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
1441
+ Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
1147
1442
  piece1A.append(letters[i] + letters[i + 13])
1148
1443
  piece2A = piece2A + letters[i + 13]
1149
1444
  piece3A = piece3A + letters[i]
1150
1445
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1151
1446
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1152
- subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
1447
+ subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
1153
1448
  exprGs.append(subscripts)
1154
1449
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1155
1450
  piece1P.append(a + b)
@@ -1158,7 +1453,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1158
1453
  piece4P = piece4P + b
1159
1454
 
1160
1455
  exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
1161
- exprP = (",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P)
1456
+ exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
1162
1457
  return [Q, (exprA, tuple(exprGs), exprP)]
1163
1458
 
1164
1459
 
@@ -1170,37 +1465,207 @@ def psgd_balance_Q(Q_in):
1170
1465
  torch._foreach_mul_(Q_in, list(norms))
1171
1466
 
1172
1467
 
1173
- def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1174
- eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
1175
- eps *= G.norm() / G.numel()
1176
- G = G + torch.randn_like(G) * eps
1177
- md = min_dtype(Q + [G])
1178
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1179
- order = G.dim()
1180
- if V is None:
1181
- conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
1468
+ @decorator
1469
+ def psgd_balance_lra(U: Tensor, V: Tensor):
1470
+ u_norm = promote(torch.linalg.vector_norm(U))
1471
+ v_norm = promote(torch.linalg.vector_norm(V))
1472
+ scale = (u_norm / v_norm) ** 0.5
1473
+ U.div_(scale)
1474
+ V.mul_(scale)
1475
+
1476
+
1477
+ @decorator
1478
+ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1479
+ dtype = min_dtype([U, V, x])
1480
+ return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1481
+
1482
+
1483
+ def update_lra_precond_(
1484
+ U: List[Tensor],
1485
+ V: List[Tensor],
1486
+ d: List[Tensor],
1487
+ vector: Tensor,
1488
+ hessian_vector: Tensor,
1489
+ eps: float,
1490
+ step: float,
1491
+ delayed: bool,
1492
+ ):
1493
+ """
1494
+ Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1495
+ """
1496
+ U_orig, V_orig, d_orig = U, V, d
1497
+
1498
+ U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
1499
+
1500
+ dtype = min_dtype([U, V, vector, hessian_vector])
1501
+ U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
1502
+
1503
+ eps = scalar_guard(eps, vector)
1504
+
1505
+ Qh = low_rank_mm(U, V, d * hessian_vector)
1506
+ Ph = d * low_rank_mm(V, U, Qh)
1507
+ rank = U.size(1)
1508
+
1509
+ VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
1510
+ I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
1511
+ IpVtU = I + VtU
1512
+ invQtv = vector / d
1513
+
1514
+ # LU factorization to reuse computation
1515
+ try:
1516
+ LU, pivots = torch.linalg.lu_factor(IpVtU)
1517
+ except RuntimeError:
1518
+ # Error:
1519
+ # U[2,2] is zero and using it on lu_solve would result in a division by zero.
1520
+ # If you still want to perform the factorization, consider calling
1521
+ # linalg.lu(A, pivot) or linalg.lu_factor_ex(A, pivot)
1522
+ # ---
1523
+ # So, we skip this step and reattempt on the next one
1524
+ return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1525
+
1526
+ invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
1527
+ invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1528
+ invPv = invPv / d
1529
+
1530
+ nablaD = Ph * hessian_vector - vector * invPv
1531
+ divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
1532
+ divisor = divisor.add(eps).sqrt().max()
1533
+ d_step = step / divisor
1534
+
1535
+ apply_flat_add(d_orig, d * nablaD, -d_step)
1536
+
1537
+ a, b = Qh, invQtv
1538
+
1539
+ precond_u = random.random() < 0.5 # update either U or V, not both at the same time
1540
+ precond = V if precond_u else U
1541
+ atV = torch.einsum("b,br->r", a, precond) # o == one
1542
+ btV = torch.einsum("b,br->r", b, precond)
1543
+ atVVt = torch.einsum("r,br->b", atV, precond)
1544
+ btVVt = torch.einsum("r,br->b", btV, precond)
1545
+ precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
1546
+ if precond_u:
1547
+ a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
1548
+ b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
1182
1549
  else:
1183
- conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
1184
- Q = [promote(q) for q in Q]
1550
+ a = a + torch.einsum("br,r->b", V, atV)
1551
+ b = b + torch.einsum("br,r->b", V, btV)
1552
+ a = torch.einsum("b,r->br", a, atV)
1553
+ b = torch.einsum("b,r->br", b, btV)
1554
+ apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
1555
+
1556
+ if not delayed:
1557
+ stochastic_add_([d], [d * nablaD], -d_step)
1558
+ stochastic_add_([U if precond_u else V], [b - a], precond_step)
1559
+ return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1560
+
1561
+
1562
+ def lra_precond(U, V, d, g):
1563
+ """
1564
+ As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
1565
+ """
1566
+ g = low_rank_mm(U, V, d * g)
1567
+ return d * low_rank_mm(V, U, g)
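Note: `lra_precond` applies the low-rank PSGD preconditioner without materializing a dense matrix: with Q = (I + U V^T) diag(d), the returned value is Q^T Q g. A small dense-matrix cross-check (illustrative sketch, not part of the package):

import torch

n, r = 6, 2
U, V = torch.randn(n, r), torch.randn(n, r)
d, g = torch.rand(n) + 0.5, torch.randn(n)

def low_rank_mm(U, V, x):                       # x + U @ V.T @ x, the same contraction as above
    return x + U @ (V.T @ x)

precond_g = d * low_rank_mm(V, U, low_rank_mm(U, V, d * g))
Q = (torch.eye(n) + U @ V.T) @ torch.diag(d)    # dense form of the factored preconditioner
assert torch.allclose(precond_g, Q.T @ Q @ g, atol=1e-5)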
1568
+
1569
+
1570
+ @decorator_knowngood
1571
+ def dampen_grad(g: Tensor, damp: float = 2**-13):
1572
+ # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
1573
+ v = torch.randn_like(g)
1574
+ return v, g + damp * g.abs().mean() * v
1575
+
1576
+
1577
+ @decorator_knowngood
1578
+ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
1579
+ update = lra_precond(U, V, d, update)
1580
+ start = 0
1581
+ update = update.flatten()
1582
+ for p in params:
1583
+ size = p.numel()
1584
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1585
+ start += size
1586
+
1587
+
1588
+ @decorator_knowngood
1589
+ def apply_flat_update(params: List[Tensor], update: Tensor):
1590
+ start = 0
1591
+ update = update.flatten()
1592
+ for p in params:
1593
+ size = p.numel()
1594
+ copy_stochastic_(p, update[start : start + size].view_as(p))
1595
+ start += size
1596
+
1597
+
1598
+ @decorator_knowngood
1599
+ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1600
+ start = 0
1601
+ update = update.flatten()
1602
+ for p in params:
1603
+ size = p.numel()
1604
+ stochastic_add_([p], [update[start : start + size].view_as(p)], alpha)
1605
+ start += size
1606
+
1607
+
1608
+ @decorator_knowngood
1609
+ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1610
+ start = 0
1611
+ outputs = []
1612
+ update = update.flatten()
1613
+ for p in params:
1614
+ size = p.numel()
1615
+ outputs.append(update[start : start + size].view_as(p))
1616
+ start += size
1617
+ return outputs
1618
+
1619
+
1620
+ @decorator_knowngood
1621
+ def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1622
+ last_dim = x[0].shape[-remaining:] if remaining else []
1623
+ return torch.cat([i.reshape(-1, *last_dim) for i in x], 0)
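Note: `flatten`, `apply_flat_add`/`apply_flat_update` and `extract_from_flat_update` move between a list of parameter tensors and one flat vector, which is the layout the LRA preconditioner above works in. A round-trip sketch of that convention (illustrative only):

import torch

params = [torch.randn(3, 4), torch.randn(5)]
flat = torch.cat([p.reshape(-1) for p in params])            # what flatten() produces
views, start = [], 0
for p in params:
    views.append(flat[start:start + p.numel()].view_as(p))   # what extract_from_flat_update yields
    start += p.numel()
assert all(torch.equal(v, p) for v, p in zip(views, params))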
1624
+
1625
+
1626
+ @decorator_knowngood
1627
+ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
1628
+ vs = []
1629
+ gs = []
1630
+ for g_ in g:
1631
+ v, g = dampen_grad(g_, damp)
1632
+ vs.append(v)
1633
+ gs.append(g)
1634
+ return flatten(vs), flatten(gs)
1635
+
1636
+
1637
+ @decorator_knowngood
1638
+ def casted_einsum(expr: str, *args: Tensor) -> Tensor:
1639
+ md = min_dtype(args)
1640
+ return torch.einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
1641
+
1642
+
1643
+ def psgd_calc_A_and_conjB(exprA, G, Q, conjB): # conjB ("V", "vector") == randn during hvp/whitening
1644
+ order = G.dim()
1645
+ if order > 1:
1646
+ conjB = conjB.view_as(G).permute(*range(1, order), 0)
1647
+ conjB = conjB.to(promote(G.dtype))
1648
+ A = casted_einsum(exprA, *Q, G)
1185
1649
  for i, q in enumerate(Q):
1650
+ q = promote(q)
1186
1651
  if q.dim() <= 1:
1187
1652
  conjB /= q
1188
1653
  else:
1189
- conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(
1190
- conjB)
1654
+ solved = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)).contiguous(), upper=True, left=False)
1655
+ conjB = solved.reshape_as(conjB)
1191
1656
  if i < order - 1:
1192
- conjB = torch.transpose(conjB, i, order - 1)
1657
+ conjB = conjB.transpose(i, -1)
1193
1658
  return A, conjB
1194
1659
 
1195
1660
 
1196
1661
  def psgd_lb(A, max_abs):
1197
1662
  A /= max_abs
1198
- a0 = torch.einsum('ij,ij->j', A, A)
1663
+ a0 = torch.einsum("ij,ij->j", A, A)
1199
1664
  i = torch.argmax(a0)
1200
1665
  x = torch.index_select(A, 1, i).flatten().contiguous()
1201
- x = torch.einsum('i,ij->j', x, A)
1666
+ x = torch.einsum("i,ij->j", x, A)
1202
1667
  x /= x.norm()
1203
- x = torch.einsum('j,kj->k', x, A)
1668
+ x = torch.einsum("j,kj->k", x, A)
1204
1669
  x = x.norm()
1205
1670
  x *= max_abs
1206
1671
  return x
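psgd_lb performs a single power-iteration-style step to obtain a cheap lower bound on the spectral norm of A, which the preconditioner update below uses for normalization. A sketch that mirrors the same einsum steps (the max_abs rescaling is omitted here because it cancels mathematically):

import torch

A = torch.randn(6, 6)
a0 = torch.einsum("ij,ij->j", A, A)                 # squared norm of each column
x = A[:, torch.argmax(a0)]                           # column with the largest norm
y = torch.einsum("i,ij->j", x, A)                    # one multiplication by A^T
y = y / y.norm()
lower_bound = torch.einsum("j,kj->k", y, A).norm()   # ||A y|| for a unit vector y
assert lower_bound <= torch.linalg.matrix_norm(A, ord=2) + 1e-4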
@@ -1217,7 +1682,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1217
1682
  term2 = promote(torch.einsum(exprG, conjB, conjB))
1218
1683
  term1, term2 = term1 - term2, term1 + term2
1219
1684
  term1 *= precond_lr
1220
- norm = term2.norm(float('inf'))
1685
+ norm = term2.norm(float("inf"))
1221
1686
  if q.dim() < 2:
1222
1687
  term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
1223
1688
  else:
@@ -1225,9 +1690,12 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1225
1690
  term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
1226
1691
  term1 = torch.mm(term1, q.to(term1.dtype))
1227
1692
  if store_triu_as_line:
1228
- term1 = triu_to_line([term1])[0][1]
1229
- o = o[1]
1230
- stochastic_add_(o, term1, -1)
1693
+ term1 = triu_to_line([term1])[0][1] # Convert update to line format
1694
+ # Apply update directly to the tensor part of the state tuple o[1]
1695
+ stochastic_add_(o[1], term1, -1)
1696
+ else:
1697
+ # Apply update to the state tensor o
1698
+ stochastic_add_(o, term1, -1)
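With store_triu_as_line, the preconditioner factor lives in state as a tuple whose second element is a flattened "line" of the upper-triangular entries, so the update is subtracted from o[1] rather than from o. The exact layout produced by triu_to_line is defined earlier in this module; the sketch below only illustrates the general idea of flat triangular storage:

import torch

q = torch.triu(torch.randn(4, 4))
idx = torch.triu_indices(4, 4)
line = q[idx[0], idx[1]]          # flat storage of the upper-triangular entries
q_back = torch.zeros(4, 4)
q_back[idx[0], idx[1]] = line     # reconstruct the triangular matrix from the line
assert torch.equal(q, q_back)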
1231
1699
 
1232
1700
 
1233
1701
  @decorator_knowngood
@@ -1245,7 +1713,7 @@ def l2_normalization_(x, clip_at: float = 1e-8):
1245
1713
  return _compilable_l2_clip_(x, clip_at)
1246
1714
 
1247
1715
 
1248
- def l2_clip_(x, clip_at: float = 1.):
1716
+ def l2_clip_(x, clip_at: float = 1.0):
1249
1717
  x = list_guard(x)
1250
1718
  return _compilable_l2_clip_(x, clip_at)
1251
1719
 
@@ -1437,12 +1905,13 @@ def warn_once(msg):
1437
1905
  _warned.add(msg)
1438
1906
 
1439
1907
 
1440
- def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1441
- name: str = 'cumulative_prob'):
1442
- group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
1908
+ def psgd_should_update(
1909
+ group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
1910
+ ):
1911
+ group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
1443
1912
  if not isinstance(prob, float):
1444
- prob = prob(group[f'{name}_prob_step'])
1445
- if group['stochastic_schedule']:
1913
+ prob = prob(group[f"{name}_prob_step"])
1914
+ if group["stochastic_schedule"]:
1446
1915
  return rng.random() < prob
1447
1916
  cumulative_prob = group.get(name, 0)
1448
1917
  group[name] = cumulative_prob + prob
@@ -1450,8 +1919,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
1450
1919
 
1451
1920
 
1452
1921
  @decorator_knowngood
1453
- def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
1454
- cast: bool = True):
1922
+ def precond_grad_cached_(
1923
+ expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True
1924
+ ):
1455
1925
  if caution:
1456
1926
  ea = _compilable_cautioning(grad, ea)
1457
1927
  md = min_dtype(list(cached_q) + [ea])
@@ -1564,18 +2034,86 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
1564
2034
 
1565
2035
 
1566
2036
  def merge_group(group, *tensors):
1567
- if not group.get('merge_dims', False):
2037
+ if not group.get("merge_dims", False):
1568
2038
  return tensors
1569
2039
  if isinstance(tensors[0], list):
1570
2040
  return [merge_group(group, *t) for t in tensors]
1571
2041
 
1572
2042
  out = []
1573
2043
  for t in tensors:
1574
- append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
1575
- 'max_precond_dim'], group.get('split', False)))
2044
+ append_or_extend(
2045
+ out,
2046
+ dim_merger(
2047
+ t,
2048
+ group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
2049
+ group.get("split", False),
2050
+ ),
2051
+ )
1576
2052
  return out
1577
2053
 
1578
2054
 
2055
+ @decorator_knowngood
2056
+ def _compilable_d_adapt_(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2057
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2058
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2059
+ next_d = d * (g * s).sum()
2060
+ s = s + u * d
2061
+ next_d = next_d / s.abs().sum()
2062
+ next_d = torch.maximum(next_d, d)
2063
+ copy_stochastic_(u_, u * d)
2064
+ copy_stochastic_(d_, next_d)
2065
+ copy_stochastic_(s_, s)
2066
+
2067
+
2068
+ def d_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor]):
2069
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2070
+ _compilable_d_adapt_(grads, update, state, delta)
2071
+
2072
+
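_compilable_d_adapt_ keeps, per tensor, an accumulator s of past scaled updates and a scale d that can only grow; the current step is still scaled by the old d, while the grown value takes effect on the next step. One step written out for a single tensor, with illustrative initial values:

import torch

g, u = torch.randn(10), torch.randn(10)   # gradient and raw update
s, d = torch.zeros(10), torch.tensor(1e-6)

next_d = d * (g * s).sum()        # correlation between the gradient and accumulated updates
s = s + u * d                     # accumulate the scaled update
next_d = next_d / s.abs().sum()
next_d = torch.maximum(next_d, d) # the scale never shrinks
scaled_update = u * d             # the current step still uses the old d
d = next_d                        # the grown d takes effect on the next step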
2073
+ @decorator_knowngood
2074
+ def _compilable_lr_adapt_(
2075
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2076
+ ):
2077
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2078
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2079
+ lr_grad = d.sigmoid()
2080
+ lr_grad = lr_grad * (1 - lr_grad)
2081
+ lr_grad = lr_grad * (s * g).mean()
2082
+ d = d - lr_grad * lr_lr
2083
+ copy_stochastic_(d_, d)
2084
+ copy_stochastic_(u_, u * d.sigmoid())
2085
+ copy_stochastic_(s_, u)
2086
+
2087
+
2088
+ def lr_adaptation(grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float):
2089
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2090
+ lr_lr = scalar_guard(lr_lr, grads[0])
2091
+ _compilable_lr_adapt_(grads, update, state, delta, lr_lr)
2092
+
2093
+
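_compilable_lr_adapt_ treats d as the logit of a learning-rate multiplier: it takes a gradient step on the logit using the correlation between the previous raw update (stored in state) and the current gradient, then scales the new update by sigmoid(d). The same step for a single tensor, with illustrative values; the pointwise variant below does the same thing element-wise instead of through the .mean():

import torch

g, u = torch.randn(10), torch.randn(10)   # current gradient and raw update
s = torch.randn(10)                        # previous step's raw update (the stored state)
d, lr_lr = torch.tensor(0.0), 0.1          # sigmoid(0) = 0.5, so the initial multiplier is one half

sig = d.sigmoid()
d = d - sig * (1 - sig) * (s * g).mean() * lr_lr   # gradient step on the logit
scaled_update = u * d.sigmoid()                    # update applied this step
s = u                                              # store the raw update for the next step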
2094
+ @decorator_knowngood
2095
+ def _compilable_pointwise_lr_adapt_(
2096
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: Tensor
2097
+ ):
2098
+ for g_, u_, s_, d_ in zip(grads, update, state, delta):
2099
+ g, u, s, d = promote(g_), promote(u_), promote(s_), promote(d_)
2100
+ lr_grad = d.sigmoid()
2101
+ lr_grad = lr_grad * (1 - lr_grad)
2102
+ lr_grad = lr_grad * s * g
2103
+ d = d - lr_grad * lr_lr
2104
+ copy_stochastic_(d_, d)
2105
+ copy_stochastic_(u_, u * d.sigmoid())
2106
+ copy_stochastic_(s_, u)
2107
+
2108
+
2109
+ def pointwise_lr_adaptation(
2110
+ grads: List[Tensor], update: List[Tensor], state: List[Tensor], delta: List[Tensor], lr_lr: float
2111
+ ):
2112
+ grads, update, state, delta = list_guard(grads, update, state, delta)
2113
+ lr_lr = scalar_guard(lr_lr, grads[0])
2114
+ _compilable_pointwise_lr_adapt_(grads, update, state, delta, lr_lr)
2115
+
2116
+
1579
2117
  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
1580
2118
  optimizers = {}
1581
2119
 
@@ -1598,8 +2136,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1598
2136
 
1599
2137
  o = optimizer(parameters, *args, **kwargs)
1600
2138
  step_fn = o.step
1601
- o.step = functools.partial(warn_once,
1602
- msg="You're trying to call `step` on a fused optimizer. This will not do anything.")
2139
+ o.step = functools.partial(
2140
+ warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything."
2141
+ )
1603
2142
 
1604
2143
  def _step(p: Tensor):
1605
2144
  seen_params.add(p)