heavyball: heavyball-1.6.3-py3-none-any.whl → heavyball-1.7.0-py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
heavyball/utils.py CHANGED
@@ -4,7 +4,7 @@ import math
4
4
  import random
5
5
  import string
6
6
  import warnings
7
- from typing import List, Optional, Tuple, Callable, Union
7
+ from typing import Callable, List, Optional, Tuple, Union
8
8
  from unittest.mock import patch
9
9
 
10
10
  import numpy as np
@@ -17,25 +17,18 @@ from torch.utils._pytree import tree_map
17
17
 
18
18
  config.cache_size_limit = 2 ** 16
19
19
 
20
- np.warnings = warnings
21
-
22
20
  compile_mode = "max-autotune-no-cudagraphs"
23
21
  dynamic = False
24
22
  compile_mode_recommended_to_none = None
25
- zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
23
+ zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster
26
24
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
27
25
 
28
- base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
29
- 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
30
- 'weight_decay': 1e-4}
31
-
32
26
 
33
27
  def decorator(func):
34
28
  compiled = None
35
29
 
36
30
  @functools.wraps(func)
37
31
  def _fn(*args, **kwargs):
38
- disable = compile_mode_recommended_to_none is None
39
32
  if is_compiling() or compile_mode_recommended_to_none is None:
40
33
  return func(*args, **kwargs)
41
34
  nonlocal compiled
@@ -66,7 +59,7 @@ einsum_base = string.ascii_lowercase
66
59
 
67
60
  @decorator_knowngood
68
61
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
69
- beta1: Tensor, decay: float, grad: List[Tensor], caution):
62
+ beta1: Tensor, decay: float, grad: List[Tensor], caution, ):
70
63
  for op, oz, u_, g_ in zip(p, z, update, grad):
71
64
  u_ = u_.view_as(op)
72
65
  p_, z_, u_ = map(promote, (op, oz, u_))
@@ -82,8 +75,8 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u
82
75
 
83
76
 
84
77
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
85
- z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0,
86
- step: int = 0, decay: float = 0.0):
78
+ z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0, step: int = 0,
79
+ decay: float = 0.0, ):
87
80
  weight = abs(lr) ** weight_lr_power * max(step, 1) ** r
88
81
  weight_sum = weight_sum + weight
89
82
 
@@ -165,7 +158,7 @@ def eps_sqrt(item, eps):
165
158
 
166
159
  @decorator_knowngood
167
160
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
168
- out: List[Optional[Tensor]]):
161
+ out: List[Optional[Tensor]]):
169
162
  g32 = promote(grad)
170
163
  s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2)
171
164
 
@@ -227,7 +220,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
227
220
 
228
221
 
229
222
  def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
230
- minimum: float = 1e-3, eps: float = 1e-8):
223
+ minimum: float = 1e-3, eps: float = 1e-8):
231
224
  if clip_val <= 0:
232
225
  return gradients
233
226
  parameters, gradients = list_guard(parameters, gradients)
@@ -253,23 +246,22 @@ def clean():
253
246
 
254
247
 
255
248
  def _ignore_warning(msg):
256
- warnings.filterwarnings('ignore', f'.*{msg}.*')
249
+ warnings.filterwarnings("ignore", f".*{msg}.*")
257
250
 
258
251
 
259
- def set_torch(benchmark_limit: int = 32):
252
+ def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"):
260
253
  cudnn.benchmark = True
261
254
  cudnn.deterministic = False
262
255
  cudnn.benchmark_limit = benchmark_limit
263
256
  torch.use_deterministic_algorithms(False)
264
257
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
265
- opt_einsum.enabled = False
266
- opt_einsum.strategy = "auto"
258
+ opt_einsum.set_flags(True, einsum_strategy)
267
259
 
268
260
  # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
269
261
  _ignore_warning(
270
- 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
262
+ "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak")
271
263
  _ignore_warning(
272
- 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')
264
+ "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak")
273
265
 
274
266
 
275
267
  @decorator
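
set_torch in the hunk above gains an einsum_strategy argument and now enables torch.backends.opt_einsum instead of disabling it. A minimal usage sketch (assumes the opt-einsum backend is available in your PyTorch install):

import torch
from heavyball import utils

utils.set_torch(benchmark_limit=32, einsum_strategy="auto")  # "greedy" and "optimal" are the other torch strategies
print(torch.backends.opt_einsum.enabled, torch.backends.opt_einsum.strategy)  # True auto
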
@@ -277,7 +269,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
277
269
  assert len(G.shape) == 2
278
270
  a, b, c = (3.4445, -4.7750, 2.0315)
279
271
  X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
280
- X /= (X.norm() + eps) # ensure top singular value <= 1
272
+ X /= X.norm() + eps # ensure top singular value <= 1
281
273
  if G.size(0) > G.size(1):
282
274
  X = X.T
283
275
  for _ in range(steps):
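
zeropower_via_newtonschulz5 above approximates the orthogonal polar factor of G (the "zeroth power" mentioned in the comment at the top of the file). A quick property check, purely illustrative:

import torch
from heavyball.utils import zeropower_via_newtonschulz5

G = torch.randn(128, 64)
X = zeropower_via_newtonschulz5(G, steps=5)
print(torch.linalg.svdvals(X.float()))  # singular values pushed toward 1 (bf16-level accuracy)
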
@@ -290,10 +282,10 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
290
282
 
291
283
 
292
284
  def ortho(x):
293
- if zeroth_power_mode == 'qr':
285
+ if zeroth_power_mode == "qr":
294
286
  return torch.linalg.qr(x).Q
295
- if zeroth_power_mode == 'svd':
296
- u, s, v = torch.linalg.svd(x)
287
+ if zeroth_power_mode == "svd":
288
+ u, _s, v = torch.linalg.svd(x)
297
289
  return u @ v.T
298
290
  raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
299
291
 
@@ -351,12 +343,12 @@ def _compilable_grafting(magnitude, direction):
351
343
 
352
344
  @decorator_knowngood
353
345
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
354
- if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
346
+ if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
355
347
  y = zeropower_via_newtonschulz5(x, 5)
356
- elif mode == 'qr':
348
+ elif mode == "qr":
357
349
  y = torch.linalg.qr(promote(x)).Q
358
- elif mode == 'svd':
359
- u, s, v = torch.linalg.svd(promote(x))
350
+ elif mode == "svd":
351
+ u, _s, v = torch.linalg.svd(promote(x))
360
352
  y = u @ v.T
361
353
  else:
362
354
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -403,7 +395,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
403
395
  q_old = promote(q.data)
404
396
 
405
397
  tmp = m @ q_old
406
- est_eig = torch.einsum('ij,ij->j', q_old, tmp)
398
+ est_eig = torch.einsum("ij,ij->j", q_old, tmp)
407
399
  sort_idx = torch.argsort(est_eig, descending=True)
408
400
 
409
401
  tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
@@ -415,19 +407,19 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona
415
407
  return
416
408
 
417
409
  assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
418
- in_str = einsum_base[:exp_avg.dim()]
419
- out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
410
+ in_str = einsum_base[: exp_avg.dim()]
411
+ out_str = einsum_base[exp_avg.dim(): 2 * exp_avg.dim()]
420
412
 
421
413
  from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None])
422
414
  if not from_shampoo:
423
415
  return
424
416
 
425
- to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
426
- out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
417
+ to_shampoo = ",".join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None])
418
+ out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
427
419
 
428
- subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
420
+ subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}"
429
421
  exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None],
430
- *[q for q in new_qs if q is not None])
422
+ *[q for q in new_qs if q is not None])
431
423
  copy_stochastic_(exp_avg, exp_avg_new)
432
424
 
433
425
  for q, q_new in zip(Q, new_qs):
@@ -453,11 +445,11 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
453
445
  while True:
454
446
  try:
455
447
  eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
456
- eigval, eigvec = torch.linalg.eigh(m + eps * eye)
448
+ _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
457
449
  eigvec = eigvec.to(device=device, dtype=dtype)
458
450
  break
459
451
  except torch.OutOfMemoryError:
460
- if m.device.type == 'cpu':
452
+ if m.device.type == "cpu":
461
453
  raise
462
454
  else:
463
455
  m = m.cpu()
@@ -489,21 +481,21 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
489
481
 
490
482
  def get_beta1(group):
491
483
  beta = None
492
- if 'beta' in group:
493
- beta = group['beta']
494
- if beta is None and 'betas' in group:
495
- beta = group['betas'][0]
484
+ if "beta" in group:
485
+ beta = group["beta"]
486
+ if beta is None and "betas" in group:
487
+ beta = group["betas"][0]
496
488
  if beta is None:
497
489
  raise ValueError("Beta not found in group.")
498
490
  return beta
499
491
 
500
492
 
501
493
  def get_beta2(group):
502
- if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
494
+ if "palm" in group and group["palm"] is True and "beta2_scale" in group:
503
495
  step = max(group.get("step", 1), 1)
504
- return 1 - step ** -group['beta2_scale']
505
- if 'betas' in group:
506
- return group['betas'][1]
496
+ return 1 - step ** -group["beta2_scale"]
497
+ if "betas" in group:
498
+ return group["betas"][1]
507
499
  raise ValueError("Beta2 not found in group.")
508
500
 
509
501
 
@@ -580,9 +572,9 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
580
572
  if not isinstance(m, Tensor):
581
573
  continue
582
574
  b = einsum_base[idx]
583
- g0 = einsum_base[:grad.dim()]
575
+ g0 = einsum_base[: grad.dim()]
584
576
  g1 = g0.replace(b, b.upper())
585
- outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
577
+ outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad)
586
578
  stochastic_lerp_(m, outer_product, 1 - beta)
587
579
 
588
580
 
@@ -623,19 +615,19 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
623
615
  """
624
616
  Initializes the preconditioner matrices (L and R in the paper).
625
617
  """
626
- state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
618
+ state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper).
627
619
  if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
628
620
  for sh in grad.shape:
629
621
  if sh > max_precond_dim or sh == 1:
630
622
  # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621
631
- state['GG'].append(None)
623
+ state["GG"].append(None)
632
624
  else:
633
- state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
625
+ state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
634
626
  else:
635
- state['GG'].append(None)
627
+ state["GG"].append(None)
636
628
 
637
- update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
638
- state['Q'] = get_orthogonal_matrix(state['GG'])
629
+ update_ggt(grad, state["GG"], max_precond_dim, precondition_1d, 0)
630
+ state["Q"] = get_orthogonal_matrix(state["GG"])
639
631
 
640
632
 
641
633
  @decorator
@@ -646,11 +638,11 @@ def project(grad, Q, back: bool):
646
638
  :param back: whether to project to Shampoo eigenbases or back to original space
647
639
  :return:
648
640
  """
649
- param = einsum_base[:grad.dim()]
650
- preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None])
641
+ param = einsum_base[: grad.dim()]
642
+ preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None])
651
643
  if preconditioners:
652
- out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
653
- out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None])
644
+ out = "".join([c.upper() if c.upper() in preconditioners else c for c in param])
645
+ out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None])
654
646
  grad = out.to(grad.dtype)
655
647
  return grad
656
648
 
@@ -667,12 +659,12 @@ def modify_closure(closure):
667
659
  """
668
660
 
669
661
  def patched_backward(self, *args, **kwargs):
670
- kwargs['create_graph'] = True
662
+ kwargs["create_graph"] = True
671
663
  return original_backward(self, *args, **kwargs)
672
664
 
673
665
  original_backward = torch.Tensor.backward
674
666
 
675
- with patch.object(torch.Tensor, 'backward', patched_backward):
667
+ with patch.object(torch.Tensor, "backward", patched_backward):
676
668
  return closure()
677
669
 
678
670
 
@@ -683,6 +675,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
683
675
  The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
684
676
  Further notice that both methods have different numerics outputs
685
677
  """
678
+
686
679
  ema_decay: float = 0.001
687
680
  compile_step: bool = False
688
681
  hessian_approx: bool = False
@@ -691,10 +684,10 @@ class StatefulOptimizer(torch.optim.Optimizer):
691
684
  finite_differences: bool = False
692
685
 
693
686
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
694
- super().__init__(params, {**defaults, 'foreach': foreach})
687
+ super().__init__(params, {**defaults, "foreach": foreach})
695
688
  self.use_ema = use_ema
696
689
  self.mapping = {}
697
- self._inner_group = {'stochastic_schedule': self.stochastic_schedule}
690
+ self._inner_group = {"stochastic_schedule": self.stochastic_schedule}
698
691
  self._precond_rng = random.Random(0x12312)
699
692
  self._is_preconditioning = None
700
693
 
@@ -710,24 +703,24 @@ class StatefulOptimizer(torch.optim.Optimizer):
710
703
  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
711
704
  for p, g in zip(p_list, g_list):
712
705
  state = self.state_(p)
713
- if 'mars_old_grad' not in state:
714
- state['mars_old_grad'] = torch.zeros_like(g)
715
- old_gs = [self.state_(p)['mars_old_grad'] for p in p_list]
706
+ if "mars_old_grad" not in state:
707
+ state["mars_old_grad"] = torch.zeros_like(g)
708
+ old_gs = [self.state_(p)["mars_old_grad"] for p in p_list]
716
709
  mars_correction(g_list, old_gs, mars_gamma, beta)
717
710
 
718
711
  def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
719
- beta1: float = -1.0):
712
+ beta1: float = -1.0):
720
713
  for p in group["params"]:
721
714
  if p in self.mapping:
722
715
  p_views = self.mapping[p]
723
716
  else:
724
717
  self.mapping[p] = p_views = merge_group(group, p)
725
718
 
726
- grad = getattr(p, 'grad', None)
719
+ grad = getattr(p, "grad", None)
727
720
  p.grad = None
728
721
 
729
722
  if grad is None:
730
- grad = [getattr(pv, 'grad', None) for pv in p_views]
723
+ grad = [getattr(pv, "grad", None) for pv in p_views]
731
724
  else:
732
725
  grad = merge_group(group, grad)
733
726
 
@@ -736,8 +729,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
736
729
  continue
737
730
  if should_promote:
738
731
  g = promote(g)
739
- if beta1 >= 0 and group.get('mars', False):
740
- self.mars_correct_list(group, [pv], [g], group['mars_gamma'], beta1)
732
+ if beta1 >= 0 and group.get("mars", False):
733
+ self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1)
741
734
  yield pv, g
742
735
 
743
736
  def state_size(self) -> int:
@@ -759,46 +752,46 @@ class StatefulOptimizer(torch.optim.Optimizer):
759
752
  def ema_update(self):
760
753
  with torch.no_grad():
761
754
  for group in self.param_groups:
762
- active_p = [p for p in group['params']]
755
+ active_p = [p for p in group["params"]]
763
756
 
764
757
  if not active_p:
765
758
  return
766
759
 
767
- k = group['ema_step'] = group.get('ema_step', -1) + 1
760
+ k = group["ema_step"] = group.get("ema_step", -1) + 1
768
761
 
769
762
  for p in active_p:
770
- if 'param_ema' not in self.state_(p):
771
- self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
763
+ if "param_ema" not in self.state_(p):
764
+ self.state_(p)["param_ema"] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
772
765
 
773
- y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
766
+ y, param_ema = zip(*[(p.data, self.state_(p)["param_ema"]) for p in active_p])
774
767
  torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
775
768
 
776
769
  def copy_emas_to_params(self):
777
770
  with torch.no_grad():
778
771
  for group in self.param_groups:
779
- active_p = [p for p in group['params']]
772
+ active_p = [p for p in group["params"]]
780
773
 
781
774
  if not active_p:
782
775
  return
783
776
 
784
777
  for p in active_p:
785
- if 'param_ema' in self.state_(p):
778
+ if "param_ema" in self.state_(p):
786
779
  p_clone = p.data.clone()
787
- set_(p.data, self.state_(p)['param_ema'])
788
- set_(self.state_(p)['param_ema'], p_clone)
780
+ set_(p.data, self.state_(p)["param_ema"])
781
+ set_(self.state_(p)["param_ema"], p_clone)
789
782
 
790
783
  def copy_params_to_emas(self):
791
784
  with torch.no_grad():
792
785
  for group in self.param_groups:
793
- active_p = [p for p in group['params']]
786
+ active_p = [p for p in group["params"]]
794
787
 
795
788
  if not active_p:
796
789
  return
797
790
 
798
791
  for p in active_p:
799
- if 'param_ema' in self.state_(p):
800
- ema_clone = self.state_(p)['param_ema'].data.clone()
801
- set_(self.state_(p)['param_ema'], p.data)
792
+ if "param_ema" in self.state_(p):
793
+ ema_clone = self.state_(p)["param_ema"].data.clone()
794
+ set_(self.state_(p)["param_ema"], p.data)
802
795
  set_(p.data, ema_clone)
803
796
 
804
797
  def _handle_closure(self, closure):
@@ -824,7 +817,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
824
817
  grads.append(g)
825
818
  p.vector = torch.randn_like(p)
826
819
  p.orig = p.data.clone()
827
- stochastic_add_(p.data, p.vector, tiny_bf16)
820
+ # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
821
+ stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
828
822
  else:
829
823
  with torch.enable_grad():
830
824
  loss = modify_closure(closure)
@@ -833,6 +827,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
833
827
  with torch.enable_grad():
834
828
  closure()
835
829
 
830
+ # we don't subtract the vector here again to avoid accumulating error from (x + eps - eps + eps - eps)
831
+ # this costs more memory, but the imprecision seems too severe to use the other method
836
832
  for group in self.param_groups:
837
833
  for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
838
834
  p.grad = grads.pop(0)
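
The hunk above switches the finite-differences Hessian-vector path from a tiny_bf16 perturbation to eps(dtype) ** 0.5. A minimal sketch of the underlying estimator (grad_fn is a hypothetical helper, not package code):

import torch

def fd_hvp(grad_fn, param: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
    # grad_fn(p) returns the gradient of the loss at p (hypothetical helper for illustration).
    eps = torch.finfo(param.dtype).eps ** 0.5  # the 1.7.0 perturbation scale; 1.6.3 used tiny_bf16
    return (grad_fn(param + eps * vector) - grad_fn(param)) / eps  # ≈ H @ vector

# toy check on f(x) = 0.5 * x^T A x, whose Hessian is A and whose gradient is A x
A = torch.randn(4, 4, dtype=torch.float64)
A = A + A.T
x, v = torch.randn(4, dtype=torch.float64), torch.randn(4, dtype=torch.float64)
print(torch.allclose(fd_hvp(lambda p: A @ p, x, v), A @ v, atol=1e-5))  # True
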
@@ -845,7 +841,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
845
841
  for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
846
842
  p.grad = g
847
843
  params, grads = zip(*[x for group in self.param_groups for x in
848
- self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
844
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
849
845
  vs = [torch.randn_like(p) for p in params]
850
846
  with torch.enable_grad():
851
847
  hvs = torch.autograd.grad(grads, params, vs)
@@ -867,7 +863,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
867
863
  # we assume that parameters are constant and that there are no excessive recompiles
868
864
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
869
865
  for group in self.param_groups:
870
- group['is_preconditioning'] = self._is_preconditioning
866
+ group["is_preconditioning"] = self._is_preconditioning
871
867
  self._step(group)
872
868
  if self.use_ema:
873
869
  self.ema_update()
@@ -892,7 +888,7 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
892
888
 
893
889
  @decorator_knowngood
894
890
  def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
895
- step: Tensor, eps: Tensor):
891
+ step: Tensor, eps: Tensor, ):
896
892
  beta1 = beta_debias(beta1, step)
897
893
  beta2 = beta_debias(beta2, step)
898
894
 
@@ -904,7 +900,7 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
904
900
 
905
901
 
906
902
  def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
907
- eps: float = 1e-8):
903
+ eps: float = 1e-8, ):
908
904
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
909
905
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
910
906
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -913,8 +909,8 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
913
909
 
914
910
  @decorator_knowngood
915
911
  def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
916
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
917
- eps: Tensor, caution: bool):
912
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
913
+ caution: bool, ):
918
914
  beta1 = beta_debias(beta1, step)
919
915
  beta2 = beta_debias(beta2, step)
920
916
 
@@ -926,8 +922,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
926
922
 
927
923
 
928
924
  def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
929
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
930
- caution: bool):
925
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
926
+ caution: bool, ):
931
927
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
932
928
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
933
929
  _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
@@ -935,7 +931,7 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
935
931
 
936
932
  @decorator_knowngood
937
933
  def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
938
- beta2: Tensor, step: Tensor, eps: Tensor):
934
+ beta2: Tensor, step: Tensor, eps: Tensor, ):
939
935
  beta1 = beta_debias(beta1, step)
940
936
  beta2 = beta_debias(beta2, step)
941
937
 
@@ -947,7 +943,7 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
947
943
 
948
944
 
949
945
  def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
950
- eps: float = 1e-8):
946
+ eps: float = 1e-8, ):
951
947
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
952
948
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
953
949
  _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -956,8 +952,8 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
956
952
 
957
953
  @decorator_knowngood
958
954
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
959
- grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
960
- caution: bool, eps: Tensor):
955
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor, caution: bool,
956
+ eps: Tensor, ):
961
957
  beta1 = beta_debias(beta1, step)
962
958
  beta2 = beta_debias(beta2, step)
963
959
 
@@ -969,8 +965,8 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
969
965
 
970
966
 
971
967
  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
972
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
973
- eps: float = 1e-8):
968
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
969
+ eps: float = 1e-8, ):
974
970
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
975
971
  beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
976
972
  _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
@@ -978,7 +974,7 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
978
974
 
979
975
  @decorator_knowngood
980
976
  def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
981
- u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
977
+ u32, g32, exp_avg_sq32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq]]
982
978
  _compilable_update_(y, u32, decay, lr, caution, g32)
983
979
 
984
980
  beta1 = beta_debias(beta1, step)
@@ -997,7 +993,7 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e
997
993
 
998
994
  @decorator_knowngood
999
995
  def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps):
1000
- g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
996
+ g32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
1001
997
  update = [e.clone() for e in exp_avg]
1002
998
 
1003
999
  beta1 = beta_debias(beta1, step)
@@ -1045,7 +1041,7 @@ def copy_stochastic_(target: Tensor, source: Tensor):
1045
1041
 
1046
1042
  @decorator_knowngood
1047
1043
  def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
1048
- g: List[Optional[Tensor]]):
1044
+ g: List[Optional[Tensor]]):
1049
1045
  for u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach
1050
1046
  u_ = promote(u_.view_as(p_))
1051
1047
  p32_ = promote(p_)
@@ -1056,7 +1052,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten
1056
1052
 
1057
1053
 
1058
1054
  def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
1059
- grad: List[Tensor] = None):
1055
+ grad: List[Tensor] = None):
1060
1056
  param, update, grad = list_guard(param, update, grad)
1061
1057
  lr = scalar_guard(lr, param[0])
1062
1058
  if not caution:
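
The hunk below changes precond_schedule to return the update probability itself (the random draw now happens in the caller), and get_soap_precond_schedule becomes a functools.partial. A small illustrative computation with made-up scheduler values:

import functools, math

def precond_schedule(step, precond_scheduler):  # as introduced in the hunk below
    precond_prob = max(step, 1) ** precond_scheduler[0]
    precond_prob = math.log10(precond_prob)
    precond_prob = precond_prob ** precond_scheduler[1] + 1
    return 1 / precond_prob

schedule = functools.partial(precond_schedule, precond_scheduler=(2, 2))  # illustrative values
print(schedule(1_000))    # 1 / (log10(1000**2)**2 + 1) = 1/37 ≈ 0.027
print(schedule(100_000))  # ≈ 0.0099 -- preconditioner updates become rarer as training progresses
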
@@ -1064,38 +1060,70 @@ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: f
1064
1060
  _compilable_update_(param, update, decay, lr, caution, grad)
1065
1061
 
1066
1062
 
1067
- def precond_schedule(step, precond_scheduler, rng):
1063
+ def precond_schedule(step, precond_scheduler):
1068
1064
  precond_prob = max(step, 1) ** precond_scheduler[0]
1069
1065
  precond_prob = math.log10(precond_prob)
1070
1066
  precond_prob = precond_prob ** precond_scheduler[1] + 1
1071
- precond_prob = 1 / precond_prob
1072
- update_precond = rng.random() < precond_prob
1073
- return update_precond
1067
+ return 1 / precond_prob
1074
1068
 
1075
1069
 
1076
1070
  def get_soap_precond_schedule(precond_scheduler):
1077
- rng = random.Random(0x12312)
1078
-
1079
- def _inner(step):
1080
- return precond_schedule(step, precond_scheduler, rng)
1081
-
1082
- return _inner
1071
+ return functools.partial(precond_schedule, precond_scheduler=precond_scheduler)
1083
1072
 
1084
1073
 
1085
1074
  def _max_idx(x: List[int]):
1086
1075
  return len(x) - 1 - np.argmax(x[::-1]) # we want to start counting from the back, as torch is fan-out/fan-in
1087
1076
 
1088
1077
 
1089
- def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
1090
- """For a scalar or tensor t, we initialize its preconditioner Q and
1078
+ @decorator_knowngood
1079
+ def mean_root(x: torch.Tensor, pow: float):
1080
+ return stochastic_round_(x, x.float().pow(pow).mean().pow(-1 / pow / 2))
1081
+
1082
+
1083
+ @decorator_knowngood
1084
+ def divided_root(x, y, pow0, pow1):
1085
+ mean_x = x.float().pow(pow0).mean().pow(1 / pow0 / 2)
1086
+ mean_y = y.float().pow(pow1).mean().pow(-1 / pow1 / 2)
1087
+ return stochastic_round_(x, mean_x * mean_y) # multiply here, as we already divide in pow -1
1088
+
1089
+
1090
+ def precond_init_scale(scale, scale_scale, grad, hessian_vector, vector):
1091
+ if scale is not None:
1092
+ warn_once(
1093
+ "It's recommended to use precond_init_scale=None (default since 1.7.x), which uses advanced heuristics.")
1094
+ if scale_scale is not None:
1095
+ warn_once(
1096
+ "precond_init_scale_scale multiplies the precond_init_scale by a constant factor. With a fixed precond_init_scale, you should explicitly multiply it into the precond_init_scale.")
1097
+ return scale
1098
+ if hessian_vector is None:
1099
+ return mean_root(grad, 4) * scale_scale
1100
+ return divided_root(vector, hessian_vector, 2, 4) * scale_scale
1101
+
1102
+
1103
+ def init_lra(grad, scale, scale_scale, rank, hessian_vector, vector, dtype=None):
1104
+ scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1105
+ U = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1106
+ V = torch.randn((*grad.shape, rank), dtype=dtype, device=grad.device)
1107
+ d = torch.full_like(grad, scale, dtype=dtype, device=grad.device)
1108
+ return U, V, d
1109
+
1110
+
1111
+ def init_Q_exprs(grad, scale, scale_scale, max_size, min_ndim_triangular, memory_save_mode, hessian_vector, vector,
1112
+ dtype=None):
1113
+ """
1114
+ For a scalar or tensor `grad`, we initialize its preconditioner Q and
1091
1115
  reusable einsum expressions for updating Q and preconditioning gradient.
1116
+
1117
+ precond init scale computation from
1118
+ https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2208-L2227
1092
1119
  """
1120
+ scale = precond_init_scale(scale, scale_scale, grad, hessian_vector, vector)
1093
1121
  letters = string.ascii_lowercase + string.ascii_uppercase
1094
- dtype = dtype if dtype is not None else t.dtype
1095
- shape = t.shape
1122
+ dtype = dtype if dtype is not None else grad.dtype
1123
+ shape = grad.shape
1096
1124
 
1097
1125
  if len(shape) == 0: # scalar
1098
- Q = [scale * torch.ones_like(t, dtype=dtype)]
1126
+ Q = [scale * torch.ones_like(grad, dtype=dtype)]
1099
1127
  exprA = ",->"
1100
1128
  exprGs = [",->"]
1101
1129
  exprP = ",,->"
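
precond_init_scale in the hunk above derives the initial PSGD preconditioner scale from the data when precond_init_scale is None. A rough restatement of the two heuristics, without the stochastic rounding used by mean_root and divided_root (illustrative only):

import torch

def approx_init_scale(grad, hessian_vector=None, vector=None):
    if hessian_vector is None:
        # gradient-only case: scale ≈ (E[g^4])^(-1/8), matching mean_root(grad, 4)
        return grad.float().pow(4).mean().pow(-1 / 8)
    # hessian-vector case: scale ≈ (E[v^2])^(1/4) * (E[h^4])^(-1/8), matching divided_root(vector, hessian_vector, 2, 4)
    return vector.float().pow(2).mean().pow(1 / 4) * hessian_vector.float().pow(4).mean().pow(-1 / 8)
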
@@ -1103,7 +1131,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1103
1131
 
1104
1132
  # Tensor
1105
1133
  if len(shape) > 13:
1106
- raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
1134
+ raise ValueError(f"Got tensor with dim {len(grad.shape)}; Einstein runs out of letters!")
1107
1135
 
1108
1136
  scale = scale ** (1 / len(shape))
1109
1137
 
@@ -1129,7 +1157,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1129
1157
  for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
1130
1158
  if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d:
1131
1159
  # use diagonal matrix as preconditioner for this dim
1132
- Q.append(scale * torch.ones(size, dtype=promote(dtype), device=t.device))
1160
+ Q.append(scale * torch.ones(size, dtype=promote(dtype), device=grad.device))
1133
1161
 
1134
1162
  piece1A.append(letters[i])
1135
1163
  piece2A = piece2A + letters[i]
@@ -1143,13 +1171,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1143
1171
  piece4P = piece4P + letters[i + 13]
1144
1172
  else:
1145
1173
  # use triangular matrix as preconditioner for this dim
1146
- Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
1174
+ Q.append(scale * torch.eye(size, dtype=dtype, device=grad.device))
1147
1175
  piece1A.append(letters[i] + letters[i + 13])
1148
1176
  piece2A = piece2A + letters[i + 13]
1149
1177
  piece3A = piece3A + letters[i]
1150
1178
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1151
1179
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1152
- subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
1180
+ subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
1153
1181
  exprGs.append(subscripts)
1154
1182
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1155
1183
  piece1P.append(a + b)
@@ -1158,7 +1186,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1158
1186
  piece4P = piece4P + b
1159
1187
 
1160
1188
  exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
1161
- exprP = (",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P)
1189
+ exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
1162
1190
  return [Q, (exprA, tuple(exprGs), exprP)]
1163
1191
 
1164
1192
 
@@ -1170,17 +1198,171 @@ def psgd_balance_Q(Q_in):
1170
1198
  torch._foreach_mul_(Q_in, list(norms))
1171
1199
 
1172
1200
 
1201
+ @decorator
1202
+ def psgd_balance_lra(U: Tensor, V: Tensor):
1203
+ u_norm = promote(torch.linalg.vector_norm(U))
1204
+ v_norm = promote(torch.linalg.vector_norm(V))
1205
+ scale = (u_norm / v_norm) ** 0.5
1206
+ U.div_(scale)
1207
+ V.mul_(scale)
1208
+
1209
+
1210
+ @decorator
1211
+ def low_rank_mm(U: Tensor, V: Tensor, x: Tensor) -> Tensor:
1212
+ dtype = min_dtype([U, V, x])
1213
+ return x + torch.einsum("br,gr,g->b", U.to(dtype), V.to(dtype), x.to(dtype)).to(x.dtype)
1214
+
1215
+
1216
+ def update_lra_precond_(U: List[Tensor], V: List[Tensor], d: List[Tensor], vector: Tensor, hessian_vector: Tensor,
1217
+ eps: float, step: float, delayed: bool, ):
1218
+ """
1219
+ Adapted from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L657
1220
+ """
1221
+ U_orig, V_orig, d_orig = U, V, d
1222
+
1223
+ U, V, d = flatten(U, 1), flatten(V, 1), flatten(d)
1224
+
1225
+ dtype = min_dtype([U, V, vector, hessian_vector])
1226
+ U, V, vector, hessian_vector = U.to(dtype), V.to(dtype), vector.to(dtype), hessian_vector.to(dtype)
1227
+
1228
+ eps = scalar_guard(eps, vector)
1229
+
1230
+ Qh = low_rank_mm(U, V, d * hessian_vector)
1231
+ Ph = d * low_rank_mm(V, U, Qh)
1232
+ rank = U.size(1)
1233
+
1234
+ VtU = torch.einsum("br,bn->rn", V, U) # (rank, rank)
1235
+ I = torch.eye(rank, dtype=VtU.dtype, device=VtU.device)
1236
+ IpVtU = I + VtU
1237
+ invQtv = vector / d
1238
+
1239
+ # LU factorization to reuse computation
1240
+ try:
1241
+ LU, pivots = torch.linalg.lu_factor(IpVtU)
1242
+ except RuntimeError:
1243
+ # Error:
1244
+ # U[2,2] is zero and using it on lu_solve would result in a division by zero.
1245
+ # If you still want to perform the factorization, consider calling
1246
+ # linalg.lu(A, pivot) or linalg.lu_factor_ex(A, pivot)
1247
+ # ---
1248
+ # So, we skip this step and reattempt on the next one
1249
+ return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1250
+
1251
+ invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, (U.T @ invQtv).view(-1, 1), adjoint=True).flatten()
1252
+ invPv = invQtv - U @ torch.linalg.lu_solve(LU, pivots, (V.T @ invQtv).view(-1, 1)).flatten()
1253
+ invPv = invPv / d
1254
+
1255
+ nablaD = Ph * hessian_vector - vector * invPv
1256
+ divisor = (Ph.square() + vector.square()) * (hessian_vector.square() + invPv.square())
1257
+ divisor = divisor.add(eps).sqrt().max()
1258
+ d_step = step / divisor
1259
+
1260
+ apply_flat_add(d_orig, d * nablaD, -d_step)
1261
+
1262
+ a, b = Qh, invQtv
1263
+
1264
+ precond_u = random.random() < 0.5 # update either U or V, not both at the same time
1265
+ precond = V if precond_u else U
1266
+ atV = torch.einsum("b,br->r", a, precond) # o == one
1267
+ btV = torch.einsum("b,br->r", b, precond)
1268
+ atVVt = torch.einsum("r,br->b", atV, precond)
1269
+ btVVt = torch.einsum("r,br->b", btV, precond)
1270
+ precond_step = step / (a.norm() * atVVt.norm() + b.norm() * btVVt.norm() + eps)
1271
+ if precond_u:
1272
+ a = torch.einsum("b,r,rg->bg", a, atV, IpVtU)
1273
+ b = torch.einsum("b,r,rg->bg", b, btV, IpVtU)
1274
+ else:
1275
+ a = a + torch.einsum("br,r->b", V, atV)
1276
+ b = b + torch.einsum("br,r->b", V, btV)
1277
+ a = torch.einsum("b,r->br", a, atV)
1278
+ b = torch.einsum("b,r->br", b, btV)
1279
+ apply_flat_add(U_orig if precond_u else V_orig, b - a, precond_step)
1280
+
1281
+ if not delayed:
1282
+ stochastic_add_([d], [d * nablaD], -d_step)
1283
+ stochastic_add_([U if precond_u else V], [b - a], precond_step)
1284
+ return U.to(U_orig[0].dtype), V.to(V_orig[0].dtype), d.to(d_orig[0].dtype)
1285
+
1286
+
1287
+ def lra_precond(U, V, d, g):
1288
+ """
1289
+ As-is from https://github.com/lixilinx/psgd_torch/blob/6dbea94915679d08a289928e6431b6ce07931aaf/preconditioned_stochastic_gradient_descent.py#L744
1290
+ """
1291
+ g = low_rank_mm(U, V, d * g)
1292
+ return d * low_rank_mm(V, U, g)
1293
+
1294
+
1295
+ @decorator_knowngood
1296
+ def dampen_grad(g: Tensor, damp: float = 2 ** -13):
1297
+ # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
1298
+ v = torch.randn_like(g)
1299
+ return v, g + damp * g.abs().mean() * v
1300
+
1301
+
1302
+ @decorator_knowngood
1303
+ def apply_lra_update(params: List[Tensor], update: Tensor, U: Tensor, V: Tensor, d: Tensor):
1304
+ update = lra_precond(U, V, d, update)
1305
+ start = 0
1306
+ update = update.flatten()
1307
+ for p in params:
1308
+ size = p.numel()
1309
+ copy_stochastic_(p, update[start: start + size].view_as(p))
1310
+ start += size
1311
+
1312
+
1313
+ @decorator_knowngood
1314
+ def apply_flat_update(params: List[Tensor], update: Tensor):
1315
+ start = 0
1316
+ update = update.flatten()
1317
+ for p in params:
1318
+ size = p.numel()
1319
+ copy_stochastic_(p, update[start: start + size].view_as(p))
1320
+ start += size
1321
+
1322
+
1323
+ @decorator_knowngood
1324
+ def apply_flat_add(params: List[Tensor], update: Tensor, alpha: Tensor):
1325
+ start = 0
1326
+ update = update.flatten()
1327
+ for p in params:
1328
+ size = p.numel()
1329
+ stochastic_add_([p], [update[start: start + size].view_as(p)], alpha)
1330
+ start += size
1331
+
1332
+
1333
+ @decorator_knowngood
1334
+ def extract_from_flat_update(params: List[Tensor], update: Tensor):
1335
+ start = 0
1336
+ outputs = []
1337
+ update = update.flatten()
1338
+ for p in params:
1339
+ size = p.numel()
1340
+ outputs.append(update[start: start + size].view_as(p))
1341
+ start += size
1342
+ return outputs
1343
+
1344
+
1345
+ def flatten(x: List[Tensor], remaining: int = 0) -> Tensor:
1346
+ return torch.cat([i.flatten(0, -1 - remaining) for i in x], 0)
1347
+
1348
+
1349
+ def dampen_multiple(g: List[Tensor], damp: float = 2 ** -13):
1350
+ vs = []
1351
+ gs = []
1352
+ for g_ in g:
1353
+ v, g = dampen_grad(g_, damp)
1354
+ vs.append(v)
1355
+ gs.append(g)
1356
+ return flatten(vs), flatten(gs)
1357
+
1358
+
1173
1359
  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1174
- eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
1175
- eps *= G.norm() / G.numel()
1176
- G = G + torch.randn_like(G) * eps
1177
- md = min_dtype(Q + [G])
1178
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1179
1360
  order = G.dim()
1180
1361
  if V is None:
1181
- conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
1182
- else:
1183
- conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
1362
+ V, G = dampen_grad(G)
1363
+ conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
1364
+ md = min_dtype(Q + [G])
1365
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1184
1366
  Q = [promote(q) for q in Q]
1185
1367
  for i, q in enumerate(Q):
1186
1368
  if q.dim() <= 1:
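
low_rank_mm in the hunk above applies (I + U V^T) to a flattened vector, and lra_precond composes it so the preconditioner acts as Q^T Q with Q = (I + U V^T) diag(d). A small check that the einsum matches the dense product (illustrative, tiny shapes):

import torch

U = torch.randn(6, 2)  # (flattened_params, rank)
V = torch.randn(6, 2)
x = torch.randn(6)

einsum_version = x + torch.einsum("br,gr,g->b", U, V, x)  # what low_rank_mm computes
dense_version = x + U @ (V.T @ x)                         # (I + U V^T) x
assert torch.allclose(einsum_version, dense_version, atol=1e-5)
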
@@ -1195,12 +1377,12 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1195
1377
 
1196
1378
  def psgd_lb(A, max_abs):
1197
1379
  A /= max_abs
1198
- a0 = torch.einsum('ij,ij->j', A, A)
1380
+ a0 = torch.einsum("ij,ij->j", A, A)
1199
1381
  i = torch.argmax(a0)
1200
1382
  x = torch.index_select(A, 1, i).flatten().contiguous()
1201
- x = torch.einsum('i,ij->j', x, A)
1383
+ x = torch.einsum("i,ij->j", x, A)
1202
1384
  x /= x.norm()
1203
- x = torch.einsum('j,kj->k', x, A)
1385
+ x = torch.einsum("j,kj->k", x, A)
1204
1386
  x = x.norm()
1205
1387
  x *= max_abs
1206
1388
  return x
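
psgd_lb above produces a cheap lower bound on the spectral norm of A; the max_abs rescaling at the start and end cancels. Restated without the in-place operations (illustrative):

import torch

A = torch.randn(64, 64)
col = A[:, torch.argmax(torch.einsum("ij,ij->j", A, A))]  # column with the largest squared norm
v = torch.einsum("i,ij->j", col, A)
v = v / v.norm()
lower_bound = torch.einsum("j,kj->k", v, A).norm()  # = ||A v|| for a unit v, hence <= ||A||_2
assert lower_bound <= torch.linalg.matrix_norm(A, ord=2) + 1e-4
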
@@ -1217,7 +1399,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1217
1399
  term2 = promote(torch.einsum(exprG, conjB, conjB))
1218
1400
  term1, term2 = term1 - term2, term1 + term2
1219
1401
  term1 *= precond_lr
1220
- norm = term2.norm(float('inf'))
1402
+ norm = term2.norm(float("inf"))
1221
1403
  if q.dim() < 2:
1222
1404
  term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16)
1223
1405
  else:
@@ -1245,7 +1427,7 @@ def l2_normalization_(x, clip_at: float = 1e-8):
1245
1427
  return _compilable_l2_clip_(x, clip_at)
1246
1428
 
1247
1429
 
1248
- def l2_clip_(x, clip_at: float = 1.):
1430
+ def l2_clip_(x, clip_at: float = 1.0):
1249
1431
  x = list_guard(x)
1250
1432
  return _compilable_l2_clip_(x, clip_at)
1251
1433
 
@@ -1438,11 +1620,11 @@ def warn_once(msg):
1438
1620
 
1439
1621
 
1440
1622
  def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
1441
- name: str = 'cumulative_prob'):
1442
- group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1
1623
+ name: str = "cumulative_prob"):
1624
+ group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
1443
1625
  if not isinstance(prob, float):
1444
- prob = prob(group[f'{name}_prob_step'])
1445
- if group['stochastic_schedule']:
1626
+ prob = prob(group[f"{name}_prob_step"])
1627
+ if group["stochastic_schedule"]:
1446
1628
  return rng.random() < prob
1447
1629
  cumulative_prob = group.get(name, 0)
1448
1630
  group[name] = cumulative_prob + prob
@@ -1451,7 +1633,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
1451
1633
 
1452
1634
  @decorator_knowngood
1453
1635
  def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
1454
- cast: bool = True):
1636
+ cast: bool = True):
1455
1637
  if caution:
1456
1638
  ea = _compilable_cautioning(grad, ea)
1457
1639
  md = min_dtype(list(cached_q) + [ea])
@@ -1564,15 +1746,16 @@ def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.999, flat_
1564
1746
 
1565
1747
 
1566
1748
  def merge_group(group, *tensors):
1567
- if not group.get('merge_dims', False):
1749
+ if not group.get("merge_dims", False):
1568
1750
  return tensors
1569
1751
  if isinstance(tensors[0], list):
1570
1752
  return [merge_group(group, *t) for t in tensors]
1571
1753
 
1572
1754
  out = []
1573
1755
  for t in tensors:
1574
- append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[
1575
- 'max_precond_dim'], group.get('split', False)))
1756
+ append_or_extend(out,
1757
+ dim_merger(t, group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"],
1758
+ group.get("split", False), ), )
1576
1759
  return out
1577
1760
 
1578
1761
 
@@ -1599,7 +1782,7 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
1599
1782
  o = optimizer(parameters, *args, **kwargs)
1600
1783
  step_fn = o.step
1601
1784
  o.step = functools.partial(warn_once,
1602
- msg="You're trying to call `step` on a fused optimizer. This will not do anything.")
1785
+ msg="You're trying to call `step` on a fused optimizer. This will not do anything.")
1603
1786
 
1604
1787
  def _step(p: Tensor):
1605
1788
  seen_params.add(p)