heavyball 0.18.8__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,8 +2,19 @@ import random
 
  import torch
 
- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+ promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)
 
 
  class PrecondScheduleSFPaLMSOAP(ScheduleFree):
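
Taken together, `_compilable_exp_avg_sq_` promotes the optimizer state to fp32, performs the debiased second-moment update and the division there, and only then writes the results back to the (possibly lower-precision) state tensors via `copy_stochastic_list_`. A rough eager-mode sketch of the same pattern, assuming the conventional `v = beta2 * v + (1 - beta2) * g**2`, `denom ≈ sqrt(v) + eps` update (the exact `eps` handling inside `exp_avg_sq_` may differ):

```python
import torch
from typing import List

def exp_avg_sq_step(exp_avg_sq: List[torch.Tensor], grad_projected: List[torch.Tensor],
                    beta2: float, eps: float) -> List[torch.Tensor]:
    # do the math in float32 regardless of the storage dtype of the state
    v32 = [v.float() for v in exp_avg_sq]
    g32 = [g.float() for g in grad_projected]

    torch._foreach_mul_(v32, beta2)
    torch._foreach_addcmul_(v32, g32, g32, value=1 - beta2)
    denom = torch._foreach_sqrt(v32)
    torch._foreach_add_(denom, eps)
    torch._foreach_div_(g32, denom)

    # heavyball writes back with stochastic rounding (copy_stochastic_list_);
    # a plain round-to-nearest copy_ is shown here for brevity
    for dst, src in zip(exp_avg_sq, v32):
        dst.copy_(src)
    for dst, src in zip(grad_projected, g32):
        dst.copy_(src)
    return denom
```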
@@ -40,8 +51,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
  data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
- weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
- betas=(None, None), split: bool = False, foreach: bool = True):
+ weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
+ split: bool = False, foreach: bool = True):
  if betas[0] is not None:
  beta = betas[0]
  defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -103,8 +114,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 
  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
 
  update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 
@@ -114,13 +125,12 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
  # to the original space
  set_(gp, project(gp, state['Q'], back=True))
 
- update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
 
  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
heavyball/psgd_kron.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
  import torch
 
  from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
- split_p_and_g_in_group, line_to_triu, triu_to_line, set_, promote
+ split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias
 
 
  class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPSGDKron(PSGDBase):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
- foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True, #
+ foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+ storage_dtype: str = 'float32', #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  if not 0.0 <= lr:
@@ -56,7 +57,7 @@ class ForeachPSGDKron(PSGDBase):
  min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
 
  def _step(self, group):
@@ -72,14 +73,15 @@ class ForeachPSGDKron(PSGDBase):
  beta = group['beta']
  store_triu_as_line = group['store_triu_as_line']
  q_dtype = getattr(torch, group['q_dtype'])
+ storage_dtype = getattr(torch, group['storage_dtype'])
 
  vals = []
 
- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)
 
  if 'Q' not in state:
- state["exp_avg"] = torch.zeros_like(g)
+ state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
  Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
  memory_save_mode, dtype=q_dtype)
  state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -94,9 +96,14 @@ class ForeachPSGDKron(PSGDBase):
 
  group["step"] += 1
 
- torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+ beta = beta_debias(beta, group["step"])
+ beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
+ stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)
 
  grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
+
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  ea = exp_avg_list.pop(0)
@@ -106,9 +113,5 @@ class ForeachPSGDKron(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
  store_triu_as_line)
- set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
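
The net effect for `ForeachPSGDKron` is a new `storage_dtype` option: the `exp_avg` buffer is allocated in that dtype and updated through `stochastic_lerp_`, while `q_dtype` still controls the Kronecker factors. A minimal usage sketch (model and loss are placeholders):

```python
import torch
from heavyball.psgd_kron import ForeachPSGDKron

model = torch.nn.Linear(64, 64)
opt = ForeachPSGDKron(model.parameters(), lr=1e-3,
                      q_dtype='float32',          # dtype of the Kronecker factors Q
                      storage_dtype='bfloat16')   # dtype of the exp_avg momentum buffer

for _ in range(3):
    loss = model(torch.randn(8, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```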
heavyball/pure_psgd.py CHANGED
@@ -70,7 +70,7 @@ class ForeachPurePSGD(PSGDBase):
 
  vals = []
 
- for p, g in split_p_and_g_in_group(group):
+ for p, g in split_p_and_g_in_group(group, should_promote=False):
  state = self.state_(p)
 
  if 'Q' not in state:
@@ -89,6 +89,7 @@ class ForeachPurePSGD(PSGDBase):
  group["step"] += 1
 
  Q_list = list(Q_list)
+ lr = -warmup(lr, group['step'], group['warmup_steps'])
  for i, (p, g) in enumerate(zip(p_list, grad_list)):
  q_orig = Q_list.pop(0)
  q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -97,8 +98,4 @@ class ForeachPurePSGD(PSGDBase):
  q32 = [promote(q_) for q_ in q]
  self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
  psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-
- grad_list = self.clip_fn(grad_list)
-
- lr = -warmup(lr, group['step'], group['warmup_steps'])
- update_param_(p_list, grad_list, lr, weight_decay)
+ update_param_([p], self.clip_fn([g]), lr, weight_decay)
@@ -2,8 +2,18 @@ import random
 
  import torch
 
- from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
- exp_avg_sq_, beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group
+ from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+ beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+ eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+ denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+ torch._foreach_div_(gp32, denom)
+
+ copy_stochastic_list_(exp_avg_sq, eas32)
+ copy_stochastic_list_(grad_projected, gp32)
 
 
  class SFPaLMForeachSOAP(ScheduleFree):
@@ -95,8 +105,8 @@ class SFPaLMForeachSOAP(ScheduleFree):
 
  # Decay the first and second moment running average coefficient
  # In-place operations to update the averages at the same time
- denom = exp_avg_sq_(exp_avg_sq, grad, new_debiased2, group["eps"])
- torch._foreach_div_(grad_projected, denom)
+ old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
+ _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])
 
  update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0
 
@@ -107,13 +117,12 @@ class SFPaLMForeachSOAP(ScheduleFree):
  # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
  set_(gp, project(gp, state['Q'], back=True))
 
- update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
- update_precond)
+ update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)
 
  # Weight decay calculated at y
  if group["weight_decay"] > 0:
  torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])
 
  lr = warmup(group['lr'], step, group['warmup_steps'])
- group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
- p_list, z, grad_projected, group['r'], step)
+ group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+ z, grad_projected, group['r'], step)
heavyball/utils.py CHANGED
@@ -3,7 +3,7 @@ import gc
  import math
  import random
  import string
- from typing import List, Optional, Tuple, Callable
+ from typing import List, Optional, Tuple, Callable, Union
 
  import numpy as np
  import torch
@@ -141,6 +141,7 @@ def beta_debias(beta, step):
  return 1 - (1 - beta) / (1 - beta ** step)
 
 
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
  def exp_avg_sq_(state, grad, beta2, eps, out=None):
  if isinstance(state, torch.Tensor):
  state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
@@ -327,6 +328,26 @@ def get_orthogonal_matrix(mat):
  return final
 
 
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+ x32 = [promote(x_) for x_ in x]
+ y32 = [promote(y_) for y_ in y]
+
+ torch._foreach_lerp_(x32, y32, a)
+
+ copy_stochastic_list_(x, x32)
+
+
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+ x32 = [promote(x_) for x_ in x]
+ y32 = [promote(y_) for y_ in y]
+
+ [x_.add_(y_, alpha=alpha) for x_, y_ in zip(x32, y32)]
+
+ copy_stochastic_list_(x, x32)
+
+
  @decorator
  def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
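
`stochastic_lerp_` and `stochastic_add_` exist because a bf16 accumulator updated with ordinary round-to-nearest silently drops updates much smaller than its own magnitude, while stochastic rounding preserves them in expectation. A self-contained illustration using the common bit-level fp32→bf16 stochastic-rounding trick (a stand-in for `copy_stochastic_`, whose body is not shown in this diff):

```python
import torch

acc_rn = torch.ones(10_000, dtype=torch.bfloat16)  # round-to-nearest accumulator
acc_sr = torch.ones(10_000, dtype=torch.bfloat16)  # stochastically rounded accumulator
step = torch.full((10_000,), 1e-4)

for _ in range(1000):
    acc_rn += step.bfloat16()  # 1e-4 is below half an ulp at 1.0, so this is a no-op

    x32 = acc_sr.float() + step
    noise = torch.randint_like(x32, 1 << 16, dtype=torch.int32)  # random bits for the low mantissa
    bits = (x32.view(torch.int32) + noise) & -65536              # keep only the upper 16 bits
    acc_sr = bits.view(torch.float32).bfloat16()                 # exact cast: low bits are zero

print(acc_rn.float().mean())  # stays at 1.0
print(acc_sr.float().mean())  # ~1.1, the accumulated value in expectation
```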
@@ -409,9 +430,12 @@ def project(grad, Q, back: bool):
 
 
  class StatefulOptimizer(torch.optim.Optimizer):
- def __init__(self, params, defaults, foreach: bool = True):
+ ema_decay: float = 0.001
+
+ def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
  self.fake_groups = {}
+ self.use_ema = use_ema
 
  def key(self, param: torch.Tensor):
  return (param.data_ptr(), tuple(param.shape))
@@ -445,6 +469,54 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def _step(self, group):
  raise NotImplementedError
 
+ def ema_update(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ k = group['ema_step'] = group.get('ema_step', -1) + 1
+
+ for p in active_p:
+ if 'param_ema' not in self.state_(p):
+ self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+ y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+ torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+
+ def copy_emas_to_params(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ p_clone = p.data.clone()
+ set_(p.data, self.state_(p)['param_ema'])
+ set_(self.state_(p)['param_ema'], p_clone)
+
+ def copy_params_to_emas(self):
+ with torch.no_grad():
+ for top_group in self.param_groups:
+ for group in self.get_groups(top_group):
+ active_p = [p for p in group['params']]
+
+ if not active_p:
+ return
+
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ ema_clone = self.state_(p)['param_ema'].data.clone()
+ set_(self.state_(p)['param_ema'], p.data)
+ set_(p.data, ema_clone)
+
  def step(self, closure: Optional[Callable] = None):
  if closure is None:
  loss = None
@@ -455,6 +527,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
  for top_group in self.param_groups:
  for group in self.get_groups(top_group):
  self._step(group)
+ if self.use_ema:
+ self.ema_update(group)
  return loss
 
 
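The new EMA hooks keep a shadow copy of every parameter in `state['param_ema']`. A minimal sketch of driving them by hand (the optimizer choice here is arbitrary; any `StatefulOptimizer` subclass gains these methods):

```python
import torch
from heavyball.psgd_kron import ForeachPSGDKron

model = torch.nn.Linear(16, 1)
opt = ForeachPSGDKron(model.parameters(), lr=1e-3)

for _ in range(10):
    model(torch.randn(4, 16)).square().mean().backward()
    opt.step()
    opt.zero_grad()
    opt.ema_update()        # lerp each state['param_ema'] towards the current weights

opt.copy_emas_to_params()   # swap the EMA weights in (e.g. for evaluation) ...
opt.copy_params_to_emas()   # ... and swap the training weights back afterwards
```
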
@@ -497,6 +571,20 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
  copy_stochastic_(t, s)
 
 
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+ def exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+ beta1 = beta_debias(beta1, step)
+ beta2 = beta_debias(beta2, step)
+
+ g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+
+ stochastic_lerp_(exp_avg, g32, 1 - beta1)
+ denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+
+ copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ return denom
+
+
  # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
  def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
@@ -523,23 +611,26 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
 
 
  @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
- def _compilable_update_one_(p, u, decay, add_fn, lr):
- p32 = promote(p)
- u32 = promote(u.view(p.shape))
+ def _compilable_update_(p, u, decay, add_fn, lr):
+ u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
+ p32, u32 = [list(map(promote, x)) for x in [p, u]]
+
  if decay > 0:
- p32.mul_(1 - decay * lr)
- if add_fn is None:
- p32.add_(u32, alpha=lr)
- else:
- add_fn(p32, u32, lr)
- copy_stochastic_(p, p32)
+ torch._foreach_mul_(p32, 1 - decay * lr)
+
+ for p32_, u32_ in zip(p32, u32): # lr is data-dependent -> can't compile a foreach
+ if add_fn is None:
+ p32_.add_(u32_, alpha=lr)
+ else:
+ add_fn(p32_, u32_, lr)
+
+ copy_stochastic_list_(p, p32)
 
 
  def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
  add_fn: callable = None):
  lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
- for p, u in zip(param, update):
- _compilable_update_one_(p, u, decay, add_fn, lr_tensor)
+ _compilable_update_(param, update, decay, add_fn, lr_tensor)
 
 
  def precond_schedule(step, precond_scheduler, rng):
@@ -630,7 +721,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  return [Q, (exprA, tuple(exprGs), exprP)]
 
 
- @decorator
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
  def psgd_balance_Q(Q_in):
  norms = torch.stack([q.norm(float("inf")) for q in Q_in])
  geometric_mean = norms.log().mean().exp()
@@ -638,12 +729,14 @@ def psgd_balance_Q(Q_in):
  torch._foreach_mul_(Q_in, list(norms))
 
 
- def psgd_calc_A_and_conjB(exprA, G, Q, V):
- md = min_dtype(Q)
- A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
+ def psgd_calc_A_and_conjB(exprA, G, Q):
+ md = min_dtype(Q + [G])
+ A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
  p = list(range(order))
- conjB = torch.permute(V.conj(), p[1:] + p[:1])
+ V = torch.randn_like(G, dtype=promote(G.dtype))
+ conjB = torch.permute(V, p[1:] + p[:1])
+ Q = [promote(q) for q in Q]
  for i, q in enumerate(Q):
  if q.dim() <= 1:
  conjB /= q
@@ -651,7 +744,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
  unsqueeze = conjB.dim() <= 1
  if unsqueeze:
  conjB = conjB.unsqueeze(0)
- conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False, out=conjB)
+ conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
  if unsqueeze:
  conjB = conjB.squeeze(0)
  if i < order - 1:
@@ -661,33 +754,37 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
 
 
  def psgd_lb(A, max_abs):
  A /= max_abs
- aa = torch.real(A * A.conj())
- value0, i = torch.max(torch.sum(aa, dim=0), 0)
- value1, j = torch.max(torch.sum(aa, dim=1), 0)
+ a0 = torch.einsum('ij,ij->j', A, A)
+ a1 = torch.einsum('ij,ij->i', A, A)
+ value0 = torch.max(a0)
+ value1 = torch.max(a1)
+ i = torch.argmax(a0)
+ j = torch.argmax(a1)
 
- ah = A.H
  comp = value0 > value1
- x = torch.where(comp, A[:, i], A[j])
- x = x.conj()
- if x.dim() > 1:
- x = torch.where(comp, x, x.T)
- torch.matmul(x, torch.where(comp, A, A.T), out=x.view(1, -1))
- x /= torch.linalg.vector_norm(x)
- torch.matmul(x, torch.where(comp, ah, ah.T), out=x.view(1, -1))
- x = torch.linalg.vector_norm(x)
+ x = torch.cond(comp, lambda a: torch.index_select(a, 1, i).flatten().contiguous(), #
+ lambda a: torch.index_select(a, 0, j).flatten().contiguous(), (A,))
+
+ x = torch.cond(comp, lambda x_, a: torch.einsum('i,ij->j', x_, a), lambda x_, a: torch.einsum('i,ji->j', x_, a),
+ (x, A,))
+ x /= x.norm()
+ x = torch.cond(comp, lambda x_, a: torch.einsum('j,kj->k', x_, a), lambda x_, a: torch.einsum('j,jk->k', x_, a),
+ (x, A,))
+ x = x.norm()
  x *= max_abs
  return x
 
 
- def psgd_update_precond(Q, exprs, V, G, step, tiny):
+ @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=False)
+ def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs
 
- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
+ A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)
 
- for q, exprG in zip(Q, exprGs):
- term1 = torch.einsum(exprG, A, A.conj())
- term2 = torch.einsum(exprG, conjB.conj(), conjB)
+ for q, exprG, o in zip(Q, exprGs, oq):
+ term1 = promote(torch.einsum(exprG, A, A))
+ term2 = promote(torch.einsum(exprG, conjB, conjB))
 
  term2 += term1 # a + b
  term1 *= 2 # 2a
@@ -696,15 +793,19 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
  else:
  term1 = term1 - term2
 
- term1 *= step
+ term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
- term1 *= q
- q.addcdiv_(term1, norm.clamp_(min=tiny), value=-1)
+ term1 *= q.to(term1.dtype)
+ term1 /= norm.clamp_(min=tiny)
  else:
  torch.triu(term1, out=term1)
- term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny)
- q.addmm_(term1, q, alpha=-1)
+ term1 /= psgd_lb(term2, norm).clamp_(tiny)
+ torch.matmul(term1, q, out=term1)
+ if store_triu_as_line:
+ term1 = triu_to_line([term1])[0][1]
+ o = o[1]
+ stochastic_add_([o], [term1], -1)
 
 
  @decorator
@@ -838,18 +939,9 @@ class PSGDBase(StatefulOptimizer):
  group[name] = cumulative_prob + prob
  return int(group[name]) > int(cumulative_prob)
 
- def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: Optional[List] = None,
- store_triu_as_line=False):
- if original_q:
- if store_triu_as_line:
- update_fn = update_triu_
- else:
- update_fn = copy_stochastic_list_
- else:
- update_fn = lambda x, y: None
- for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
- psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
- update_fn(oq, Q)
+ def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
+ for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
+ psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)
 
  if self.should_update(group, self.balance_probability, "balance_prob"):
  for g, q in zip(grad_list, original_q if original_q else q_list):
@@ -896,13 +988,19 @@ def merge_group(group, *tensors):
  return out
 
 
- def split_p_and_g_in_group(group: dict, skip_none: bool = True):
+ def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
  for p in group["params"]:
  if skip_none and p.grad is None:
  continue
 
- grad = None if p.grad is None else promote(p.grad)
- p.grad = None
+ if p.grad is None:
+ grad = None
+ else:
+ if should_promote:
+ grad = promote(p.grad)
+ else:
+ grad = p.grad
+ p.grad = None
 
  p_views = merge_group(group, p)
  if grad is not None:
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 0.18.8
+ Version: 0.20.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
  The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
  largely static alternative to `torch.optim` with more and better optimizers.
 
- Currently (2024-11-21, 0.18.6), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
+ Currently (2024-11-22, 0.19.1), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
  recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
 
  ## Features
@@ -45,8 +45,10 @@ recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psg
  * **ScheduleFree**: No learning rate schedule, but better convergence
  * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
  better step-per-second in late convergence (explained below)
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) to trade off memory usage for memory
- bandwidth; turn it off for lower overheads (for more, see [PSGD Efficiency](docs/psgd_efficiency.md))
+ * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
+ usage for memory
+ bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
+ stochastic rounding
 
  ## Getting started
 
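For a rough sense of the trade-off named in that bullet: storing only the upper triangle of an `n × n` preconditioner factor keeps about `n*(n+1)/2` entries instead of `n**2`, and a lower-precision `q_dtype` halves the bytes again (illustrative arithmetic, not measured numbers):

```python
n = 2048                                   # one Kronecker factor of size n x n
full_fp32 = n * n * 4 / 2**20              # dense float32 matrix          -> 16.0 MiB
line_fp32 = n * (n + 1) // 2 * 4 / 2**20   # store_triu_as_line=True       -> ~8.0 MiB
line_bf16 = n * (n + 1) // 2 * 2 / 2**20   # ... plus q_dtype='bfloat16'   -> ~4.0 MiB
print(full_fp32, line_fp32, line_bf16)
```
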
@@ -76,19 +78,19 @@ for _ in range(1000):
 
  ## Optimizers
 
- | Name | Description | Advantages / Disadvantages |
- |-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
+ | Name | Description | Advantages / Disadvantages |
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+ | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
+ | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better converence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
+ | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
+ | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
+ | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
+ | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+ | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be ammortized; better loss-at-second) |
+ | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized) |
  | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more tham AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of step |
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+ | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
+ | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be ammortized), goes to 0 with increasing number of steps |
 
  ## Precond Schedule