heavyball-1.4.4-py3-none-any.whl → heavyball-1.5.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
heavyball/__init__.py CHANGED
@@ -116,7 +116,7 @@ class ForeachSOAP(C.BaseOpt):
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
- data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+ data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -129,8 +129,10 @@ class ForeachSOAP(C.BaseOpt):

  if use_precond_schedule:
  del defaults['precondition_frequency']
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
  else:
  del defaults['precond_scheduler']
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
  C.scale_by_soap)

@@ -149,6 +151,53 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
  palm: bool = True


+ class OrthoLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+ C.orthogonalize_grad_to_param, C.scale_by_laprop)
+
+
+ class ForeachAdamW(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
+
+
+ class OrthoAdamW(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+ C.orthogonalize_grad_to_param, C.scale_by_adam)
+
+
+ class AdamWOrtho(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
+ C.orthogonalize_grad_to_param)
+
+
  class ForeachPSGDKron(C.BaseOpt):
  """
  Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -162,7 +211,7 @@ class ForeachPSGDKron(C.BaseOpt):

  def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
- momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+ momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
  stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
@@ -172,6 +221,8 @@ class ForeachPSGDKron(C.BaseOpt):
  precond_init_scale=1.0, precond_lr=0.1):
  defaults = locals()
  defaults.pop("self")
+ self.precond_schedule = defaults.pop(
+ "preconditioner_update_probability") or utils.precond_update_prob_schedule()
  params = defaults.pop("params")

  delayed = C.default(delayed, self.delayed)
@@ -181,8 +232,7 @@ class ForeachPSGDKron(C.BaseOpt):

  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
  *(C.exp_avg,) * exp_avg_input, #
- functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
- prob=preconditioner_update_probability))
+ functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached))


  class ForeachPurePSGD(ForeachPSGDKron):
@@ -202,6 +252,10 @@ class ForeachDelayedPSGD(ForeachPSGDKron):
  delayed: bool = True


+ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
+ hessian_approx = True
+
+
  PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
@@ -225,4 +279,4 @@ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD",
  "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
  "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
  "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
- "ForeachRMSprop", "ForeachMuon"]
+ "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
heavyball/chainable.py CHANGED
@@ -127,7 +127,7 @@ def zero_guard(*names):


  def copy_guard(index, *names):
- return functools.partial(CopyGuard, index=index, names=names,)
+ return functools.partial(CopyGuard, index=index, names=names, )


  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -188,6 +188,11 @@ def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
  raise SkipUpdate


+ @no_state
+ def orthogonalize_grad_to_param(group, update, grad, param):
+ return utils.orthogonalize_grad_to_param(param, update, group['eps'])
+
+
  @copy_guard(2, "z")
  @no_state
  def update_by_schedule_free(group, update, grad, param, z):
@@ -312,17 +317,23 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:

  for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
  utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
- utils.beta_debias(group['shampoo_beta'], group['step']), precond_schedule(group))
+ utils.beta_debias(group['shampoo_beta'], group['step']),
+ group['is_preconditioning'])
  return precond


  def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, prob: Optional[callable] = None):
  if prob is None:
  prob = utils.precond_update_prob_schedule()
- if not precond_schedule(group, prob, name=f"cumulative_prob_{id(Q)}"):
+
+ if not group['is_preconditioning']:
  return Q_mat

- utils.psgd_update_precond(Q_mat, exprs, grad, group['precond_lr'], Q, group['store_triu_as_line'])
+ utils.psgd_update_precond(Q_mat, exprs, getattr(param, 'hessian_vector', grad), group['precond_lr'], Q,
+ group['store_triu_as_line'], getattr(param, 'vector', None))
+ if hasattr(param, 'vector'):
+ del param.vector
+ del param.hessian_vector

  if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
  if group['store_triu_as_line']:
@@ -330,7 +341,15 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
  else:
  utils.psgd_balance_Q(Q)

- return _update_psgd_cache(cached, Q_cache, Q_mat)
+ if isinstance(prob, float):
+ float_prob = prob
+ else:
+ float_prob = prob(group.get(f'cumulative_prob_{id(Q)}_prob_step', 1))
+ group['is_cached'] = should_use_cache = cached and float_prob < 0.5
+
+ if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step
+ return _update_psgd_cache(cached, Q_cache, Q_mat)
+ return Q_mat


  def _update_psgd_cache(cached, Q_cache, q):
@@ -345,14 +364,14 @@ def _update_psgd_cache(cached, Q_cache, q):
  return Q_cache


- def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
- if cached:
+ def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+ if group.get('is_cached', False):
  return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
  return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)


- def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, exprs, update, Q_mat, Q_cache):
- if cached:
+ def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
+ if group.get('is_cached', False):
  utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'],
  group['caution'], *Q_cache)
  else:
@@ -368,7 +387,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
- return _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
+ return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)


  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
@@ -376,9 +395,9 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
  def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
- _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
- Q_mat, Q, exprs, prob)
+ precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+ _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+ Q_mat, Q, exprs, prob)
  return precond


@@ -389,7 +408,7 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
- _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
+ _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
  raise SkipUpdate


@@ -398,9 +417,9 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
  def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- _fused_cached_psgd_precond_grad(group, update, param, cached, cache_expr, exprs, update, Q_mat, Q_cache)
- _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
- Q_mat, Q, exprs, prob)
+ _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache)
+ _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
+ Q_mat, Q, exprs, prob)
  raise SkipUpdate


@@ -449,6 +468,7 @@ class ChainOpt(utils.StatefulOptimizer):
  group['base_lr'] = group['lr']

  vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
+
  if not vals:
  return
  p, g = zip(*vals)
@@ -464,12 +484,7 @@ class ChainOpt(utils.StatefulOptimizer):
  break

  group['step'] = state['step'] = step = step + 1
-
- if group['warmup_steps'] and step < group['warmup_steps']:
- group['prev_lr'] = group['lr'] = group['base_lr'] * step / group['warmup_steps']
-
- else:
- group['prev_lr'] = group['lr'] = group['base_lr']
+ group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1)

  if not group['foreach'] or len(p) == 1:
  for param, grad in zip(p, g):
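The warmup branch in `ChainOpt` collapses into one expression, `lr = base_lr * step / max(step, warmup_steps + 1)`: a linear ramp that reaches the full learning rate at step `warmup_steps + 1` and stays there, with no conditional and no division by zero when `warmup_steps == 0`. A small worked example of the schedule (standalone sketch, not heavyball code):

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Same expression as the new ChainOpt code path: ramp linearly, then stay at base_lr.
    return base_lr * step / max(step, warmup_steps + 1)

for step in (1, 2, 3, 4, 5, 10):
    print(step, warmup_lr(1.0, step, warmup_steps=3))
# -> 0.25, 0.5, 0.75, 1.0, 1.0, 1.0: full LR is reached at step warmup_steps + 1
```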
heavyball/utils.py CHANGED
@@ -54,12 +54,6 @@ def decorator_knowngood(func: Callable):
  einsum_base = string.ascii_lowercase + string.ascii_uppercase


- def warmup(lr: float, step: int, warmup_steps: int):
- if step >= warmup_steps: # if instead of min to guard against 0 div
- return lr
- return lr * step / warmup_steps
-
-
  @decorator_knowngood
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor,
  beta1: Tensor, decay: float, grad: List[Tensor], caution):
@@ -323,6 +317,10 @@ def nesterov_momentum(state, grad, beta):
  return grad


+ def _compilable_grafting(magnitude, direction):
+ return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
+
+
  # mode in ("newtonschulz", "qr", "svd")
  # scale_mode in ("none", "scale", "graft")
  @decorator_knowngood
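The new `_compilable_grafting` helper rescales one tensor's direction to another tensor's norm; the "graft" branch of `inplace_orthogonal_` below is rewritten in terms of it. A minimal standalone sketch of the idea (plain PyTorch, outside heavyball's compiled helpers):

```python
import torch

def graft(magnitude: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    # Keep the direction of `direction`, but borrow the norm of `magnitude`.
    return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))

g = torch.randn(4, 4)   # e.g. the raw gradient, supplying the magnitude
d = torch.randn(4, 4)   # e.g. an orthogonalized update, supplying the direction
out = graft(g, d)
assert torch.allclose(out.norm(), g.norm())  # the grafted update carries g's norm
```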
@@ -341,12 +339,17 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  elif scale_mode == "scale":
  y *= max(1, x.size(0) / x.size(1)) ** 0.5
  elif scale_mode == "graft":
- y *= x.norm() / y.norm().clamp(min=1e-6)
+ y = _compilable_grafting(x, y)
  else:
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
  set_(out, y)


+ @decorator_knowngood
+ def _compilable_scatter_set(target, source, index):
+ target[:] = source.contiguous()[index].reshape_as(target)
+
+
  def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
  """
  Computes the eigenbases of the preconditioner using one round of power iteration
@@ -381,7 +384,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):

  indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
  for i, ind in enumerate(indices))
- set_(exp_avg_sq, exp_avg_sq[indices])
+ _compilable_scatter_set(exp_avg_sq, exp_avg_sq, indices)


  def get_orthogonal_matrix(mat):
@@ -500,7 +503,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
  copy_stochastic_(x_, x32 + y32 * alpha)


- def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
+ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor] = 1):
  x, y = list_guard(x, y)
  alpha = scalar_guard(alpha, x[0])
  _compilable_stochastic_add_(x, y, alpha)
@@ -591,25 +594,26 @@ def project(grad, Q, back: bool):
  class StatefulOptimizer(torch.optim.Optimizer):
  ema_decay: float = 0.001
  compile_step: bool = False
+ hessian_approx: bool = False
+ precond_schedule: Union[Callable, float, None] = None
+ stochastic_schedule: bool = False

  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
- self.fake_groups = {}
  self.use_ema = use_ema
  self.mapping = {}
+ self._inner_group = {'stochastic_schedule': self.stochastic_schedule}
+ self._precond_rng = random.Random(0x12312)
+ self._is_preconditioning = None

- def get_groups(self, group):
- if group['foreach']:
- return [group]
-
- for p in group['params']:
- if p not in self.fake_groups:
- self.fake_groups[p] = {**group, 'params': [p]}
+ if self.hessian_approx and self.compile_step:
+ raise ValueError("Hessian approximation can't be used with compile_step.")

- return [self.fake_groups[p] for p in group['params']]
+ def get_groups(self, group):
+ return [group]

  def state_(self, arg: Tensor):
- return self.state[self.mapping.get(arg, arg)]
+ return self.state[arg]

  def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta):
  for p, g in zip(p_list, g_list):
@@ -622,36 +626,27 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
  beta1: float = -1.0):
  for p in group["params"]:
- if p.grad is None:
- if skip_none:
- continue
- grad = None
+ if p in self.mapping:
+ p_views = self.mapping[p]
  else:
- if should_promote:
- grad = promote(p.grad)
- else:
- grad = p.grad
- if beta1 >= 0 and group.get('mars', False):
- self.mars_correct_list(group, [p], [grad], group['mars_gamma'], beta1)
+ self.mapping[p] = p_views = merge_group(group, p)

- p.grad = None
+ grad = getattr(p, 'grad', None)
+ p.grad = None

- if self.compile_step:
- yield p, grad
- continue
-
- p_views = merge_group(group, p)
- if grad is not None:
- grad = merge_group(group, grad)
- for i, pv in enumerate(p_views):
- self.mapping[pv] = (p, i)
- if isinstance(p_views, Tensor):
- yield p_views, grad
- continue
  if grad is None:
- yield from zip(p_views, [None] * len(p_views))
- continue
- yield from zip(p_views, grad)
+ grad = [getattr(pv, 'grad', None) for pv in p_views]
+ else:
+ grad = merge_group(group, grad)
+
+ for pv, g in zip(p_views, grad):
+ if skip_none and g is None:
+ continue
+ if should_promote:
+ g = promote(g)
+ if beta1 >= 0 and group.get('mars', False):
+ self.mars_correct_list(group, [pv], [g], group['mars_gamma'], beta1)
+ yield pv, g

  def state_size(self) -> int:
  total_bytes = 0
@@ -671,67 +666,89 @@ class StatefulOptimizer(torch.optim.Optimizer):

  def ema_update(self):
  with torch.no_grad():
- for top_group in self.param_groups:
- for group in self.get_groups(top_group):
- active_p = [p for p in group['params']]
+ for group in self.param_groups:
+ active_p = [p for p in group['params']]

- if not active_p:
- return
+ if not active_p:
+ return

- k = group['ema_step'] = group.get('ema_step', -1) + 1
+ k = group['ema_step'] = group.get('ema_step', -1) + 1

- for p in active_p:
- if 'param_ema' not in self.state_(p):
- self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+ for p in active_p:
+ if 'param_ema' not in self.state_(p):
+ self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

- y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
- torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+ y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+ torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))

  def copy_emas_to_params(self):
  with torch.no_grad():
- for top_group in self.param_groups:
- for group in self.get_groups(top_group):
- active_p = [p for p in group['params']]
+ for group in self.param_groups:
+ active_p = [p for p in group['params']]

- if not active_p:
- return
+ if not active_p:
+ return

- for p in active_p:
- if 'param_ema' in self.state_(p):
- p_clone = p.data.clone()
- set_(p.data, self.state_(p)['param_ema'])
- set_(self.state_(p)['param_ema'], p_clone)
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ p_clone = p.data.clone()
+ set_(p.data, self.state_(p)['param_ema'])
+ set_(self.state_(p)['param_ema'], p_clone)

  def copy_params_to_emas(self):
  with torch.no_grad():
- for top_group in self.param_groups:
- for group in self.get_groups(top_group):
- active_p = [p for p in group['params']]
+ for group in self.param_groups:
+ active_p = [p for p in group['params']]

- if not active_p:
- return
+ if not active_p:
+ return

- for p in active_p:
- if 'param_ema' in self.state_(p):
- ema_clone = self.state_(p)['param_ema'].data.clone()
- set_(self.state_(p)['param_ema'], p.data)
- set_(p.data, ema_clone)
+ for p in active_p:
+ if 'param_ema' in self.state_(p):
+ ema_clone = self.state_(p)['param_ema'].data.clone()
+ set_(self.state_(p)['param_ema'], p.data)
+ set_(p.data, ema_clone)

  def step(self, closure: Optional[Callable] = None):
+ if self.precond_schedule is None:
+ self._is_preconditioning = False
+ else:
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+ hessian_approx = self.hessian_approx and self._is_preconditioning
  if closure is None:
+ if hessian_approx:
+ raise ValueError("Hessian approximation requires a closure.")
  loss = None
  else:
  with torch.enable_grad():
  loss = closure()
+ if hessian_approx:
+ grads = []
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ grads.append(g)
+ p.vector = torch.randn_like(p)
+ p.orig = p.data.clone()
+ stochastic_add_(p.data, p.vector, tiny_bf16)
+
+ with torch.enable_grad():
+ closure()
+
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ p.grad = grads.pop(0)
+ stochastic_add_(g, p.grad, -1)
+ p.hessian_vector = g
+ p.data.copy_(p.orig)
+ del p.orig

  # we assume that parameters are constant and that there are no excessive recompiles
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
- for top_group in self.param_groups:
- for group in self.get_groups(top_group):
- self._step(group)
- self.mapping.clear()
- if self.use_ema:
- self.ema_update(group)
+ for group in self.param_groups:
+ group['is_preconditioning'] = self._is_preconditioning
+ self._step(group)
+ if self.use_ema:
+ self.ema_update(group)

  return loss
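The new `hessian_approx` path in `StatefulOptimizer.step` builds a Hessian-vector pair by finite differences: it stores the gradients, nudges every parameter along a random probe `p.vector` by a tiny amount (`tiny_bf16`), re-runs the closure, and keeps the gradient difference as `p.hessian_vector`, which `psgd_update_precond` later consumes through its new `V` argument. A standalone sketch of the same idea (names and the explicit `eps` division are illustrative; heavyball keeps the raw difference):

```python
import torch

def finite_difference_hvp(closure, params, eps: float = 1e-3):
    """Approximate (v, H @ v) per parameter via grad(x + eps*v) - grad(x)."""
    closure()                                    # assumed to zero grads and call backward()
    g0 = [p.grad.detach().clone() for p in params]
    vs = [torch.randn_like(p) for p in params]   # random probe directions
    with torch.no_grad():
        for p, v in zip(params, vs):
            p.add_(v, alpha=eps)                 # x <- x + eps * v
    closure()                                    # gradients at the perturbed point
    hvs = [(p.grad - g).div_(eps) for p, g in zip(params, g0)]
    with torch.no_grad():
        for p, v in zip(params, vs):
            p.sub_(v, alpha=eps)                 # restore the original parameters
    return vs, hvs
```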
 
@@ -753,22 +770,23 @@ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):

  @decorator_knowngood
  def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
- step: Tensor):
+ step: Tensor, eps: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

  g32 = list(map(promote, grad))

  exp_avg32 = _lerp32(exp_avg, g32, beta1)
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
+ denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
  u32 = torch._foreach_div(exp_avg32, denom)
  copy_stochastic_list_(grad, u32)


- def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+ eps: float):
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
- _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+ _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
  return grad


@@ -943,6 +961,15 @@ def precond_schedule(step, precond_scheduler, rng):
  return update_precond


+ def get_soap_precond_schedule(precond_scheduler):
+ rng = random.Random(0x12312)
+
+ def _inner(step):
+ return precond_schedule(step, precond_scheduler, rng)
+
+ return _inner
+
+
  def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
  """For a scalar or tensor t, we initialize its preconditioner Q and
  reusable einsum expressions for updating Q and preconditioning gradient.
@@ -1030,14 +1057,17 @@ def psgd_balance_Q(Q_in):
  torch._foreach_mul_(Q_in, list(norms))


- def psgd_calc_A_and_conjB(exprA, G, Q):
+ def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
  eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
  eps *= G.norm() / G.numel()
  G = G + torch.randn_like(G) * eps
  md = min_dtype(Q + [G])
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
  order = G.dim()
- conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+ if V is None:
+ conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+ else:
+ conjB = V.permute(*range(1, order), 0).to(promote(G.dtype))
  Q = [promote(q) for q in Q]
  for i, q in enumerate(Q):
  if q.dim() <= 1:
@@ -1066,11 +1096,11 @@ def psgd_lb(A, max_abs):


  @decorator
- def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
+ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs

- A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)
+ A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)

  for q, exprG, o in zip(Q, exprGs, oq):
  term1 = promote(torch.einsum(exprG, A, A))
@@ -1325,6 +1355,29 @@ def mars_correction(g, old_g, beta1, gamma):
  _compilable_mars_correction_(g, old_g, a)


+ @decorator_knowngood
+ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps: Tensor, graft: bool = True):
+ """
+ Implements OrthoGrad from "Grokking at the Edge of Numerical Stability" (https://arxiv.org/abs/2501.04697)
+ """
+
+ for w, g in zip(weight, grad):
+ proj = promote((w * g).sum()) / promote((w * w).sum()).add(eps)
+ out = promote(g) - proj * promote(w) # promote in this funky way to keep traffic minimal
+
+ if graft:
+ out = _compilable_grafting(g, out)
+
+ copy_stochastic_(g, out)
+
+
+ def orthogonalize_grad_to_param(weight, grad, eps, graft=True):
+ weight, grad = list_guard(weight, grad)
+ eps = scalar_guard(eps, weight[0])
+ _compilable_orthogonalization(weight, grad, eps, graft)
+ return grad
+
+
  @decorator_knowngood
  def _compilable_cautioning(g: Tensor, update: Tensor):
  mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
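`orthogonalize_grad_to_param` implements OrthoGrad (arXiv:2501.04697): the component of the gradient parallel to the weight is removed, and the result is optionally grafted back onto the original gradient norm. A plain-PyTorch sketch of the same math (helper names are illustrative):

```python
import torch

def orthograd(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-8, graft: bool = True) -> torch.Tensor:
    # g_perp = g - (<w, g> / (<w, w> + eps)) * w  -- drop the radial component of the gradient
    proj = (w * g).sum() / ((w * w).sum() + eps)
    out = g - proj * w
    if graft:
        # Grafting restores the original gradient norm while keeping the orthogonal direction.
        out = out * (g.norm() / out.norm().clamp(min=1e-6))
    return out

w, g = torch.randn(10), torch.randn(10)
print(torch.dot(w, orthograd(w, g, graft=False)))  # ~0: to first order, the update no longer changes ||w||
```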
@@ -1390,12 +1443,15 @@ def fused_hook(parameters, optimizer, *args, **kwargs):
  seen_params = set()

  o = optimizer(parameters, *args, **kwargs)
+ step_fn = o.step
+ o.step = functools.partial(warn_once,
+ msg="You're trying to call `step` on a fused optimizer. This will not do anything.")

  def _step(p: Tensor):
  seen_params.add(p)

  if len(seen_params) < param_count:
- o.step()
+ step_fn()
  o.zero_grad()
  seen_params.clear()
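`fused_hook` wraps an optimizer so it is driven by per-parameter gradient hooks rather than an explicit `step()` call; the change above keeps the real step as `step_fn` and swaps the public `o.step` for a `warn_once` call, so stepping a fused optimizer by hand no longer double-updates. A hedged usage sketch (the model and loop are placeholders, and the hook mechanics are inferred from surrounding code rather than shown in this hunk):

```python
import torch
from heavyball import ForeachAdamW
from heavyball.utils import fused_hook

model = torch.nn.Linear(16, 4)
# The optimizer steps itself from inside the per-parameter gradient hooks, so the
# training loop only needs backward(); a manual opt.step() would now just warn.
fused_hook(model.parameters(), ForeachAdamW, lr=1e-3)

for _ in range(3):
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
```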
 
heavyball-1.5.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.4.4
+ Version: 1.5.1
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
heavyball-1.5.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
+ heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
+ heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
+ heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
+ heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.5.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
2
- heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
3
- heavyball/utils.py,sha256=lFwN8T-dlldmOe-Qd6iWhSqqNfWl7IBawLWAo5l9rPw,48071
4
- heavyball-1.4.4.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.4.4.dist-info/METADATA,sha256=w5nAamE6sr08elqo2fS6B_kXktOMXxFQvyJTkRT4Eqo,43584
6
- heavyball-1.4.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.4.4.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.4.4.dist-info/RECORD,,