heavyball-1.1.3-py3-none-any.whl → heavyball-1.2.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.
heavyball/chainable.py CHANGED
@@ -1,5 +1,6 @@
 import functools
 import random
+import warnings
 from typing import Optional, Union, Literal

 import torch
@@ -251,7 +252,11 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     step = group['step']
     if 'precondition_frequency' in group:
         return step > 0 and step % group['precondition_frequency'] == 0
-    rng = random.Random(0x172381 ^ step)
+    if isinstance(step, torch.Tensor):
+        utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
+        rng = random.Random(0x172381)
+    else:
+        rng = random.Random(0x172381 ^ step)
     if 'precond_scheduler' in group:
         return utils.precond_schedule(step, group['precond_scheduler'], rng)
     if prob is not None:
@@ -414,6 +419,8 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):


 class ChainOpt(utils.StatefulOptimizer):
+    compile_step: bool = False
+
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
         self.fns = tuple(fns)
@@ -421,23 +428,36 @@ class ChainOpt(utils.StatefulOptimizer):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
-        step = group['step'] = group.get('step', 0) + 1
-        if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = -group['base_lr'] * step / group['warmup_steps']
-        else:
-            group['lr'] = -group['base_lr']

         vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)

+        for param in p:
+            state = self.state_(param)
+            if 'step' not in state:
+                if self.compile_step:
+                    step = utils.scalar_guard(0, param)
+                state['step'] = step
+            step = state['step'].add_(1)
+            break
+
+        group['step'] = step
+
+        if group['warmup_steps'] and step < group['warmup_steps']:
+            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+        else:
+            group['lr'] = group['base_lr']
+
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
-            return
+        else:
+            chain(self.state_, group, g, p, *self.fns)

-        chain(self.state_, group, g, p, *self.fns)
+        group['lr'] = None
+        group['step'] = None


 use_default = object()
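
In this version, `ChainOpt._step` keeps the step counter in per-parameter state (created as a tensor via `utils.scalar_guard` when `compile_step` is set) and now computes a positive warmup learning rate; the negation that was previously baked into `group['lr']` instead lives in `utils._compilable_update_`, which applies `-lr` itself (see the utils.py hunk below). As a minimal illustration of the warmup schedule the new code computes, assuming nothing beyond the hunk above (`warmup_lr` is a hypothetical helper, not part of the package):

def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Linear warmup from base_lr / warmup_steps up to base_lr, then constant.
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr

assert warmup_lr(1e-3, 5, 100) == 1e-3 * 5 / 100   # still warming up
assert warmup_lr(1e-3, 200, 100) == 1e-3            # past warmup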
@@ -454,15 +474,15 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):


 def default(a, b):
-    return b if a is None or a is use_default else a
+    return b if a is use_default else a


 # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
-_scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd, #
-                        scale_by_psgd: update_by_psgd, #
-                        scale_by_adam: update_by_adam, #
-                        scale_by_laprop: update_by_laprop, #
-                        scale_by_adopt: update_by_adopt}
+_scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+                        scale_by_psgd.get_fn(): update_by_psgd, #
+                        scale_by_adam.get_fn(): update_by_adam, #
+                        scale_by_laprop.get_fn(): update_by_laprop, #
+                        scale_by_adopt.get_fn(): update_by_adopt}


 class BaseOpt(ChainOpt):
@@ -470,16 +490,17 @@ class BaseOpt(ChainOpt):
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
-    compile_step: bool = False

     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool = use_default, *fns):
+                 palm: bool = use_default, compile_step: bool = use_default, *fns):
         if default(update_clipping, self.update_clipping) is None:
             if fns and self.auto_fuse:
                 args, kwargs = None, None
                 fn = fns[-1]
                 if isinstance(fn, functools.partial):
-                    fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
+                    fn, args, kwargs = fn.func, fn.args, fn.keywords
+                if isinstance(fn, FunctionTransform):
+                    fn = fn.get_fn()
                 if fn in _scale_to_update_map:
                     fn = _scale_to_update_map[fn]
                     if args is not None:
@@ -491,6 +512,7 @@ class BaseOpt(ChainOpt):

         fns = tuple(fns)

+        self.compile_step = default(compile_step, self.compile_step)
         if default(palm, self.palm):
             fns = (palm_beta2,) + fns
         if default(gradient_clipping, self.gradient_clipping) is not None:
heavyball/utils.py CHANGED
@@ -1,16 +1,9 @@
-"""
-
-
-Originally from Evan Walters and Omead Pooladzandi, 2024
-Modified under Creative Commons Attribution 4.0 International
-Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
-"""
-
 import functools
 import gc
 import math
 import random
 import string
+import warnings
 from typing import List, Optional, Tuple, Callable, Union

 import numpy as np
@@ -70,16 +63,16 @@ def warmup(lr: float, step: int, warmup_steps: int):
 @decorator_knowngood
 def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
                                beta1: Tensor, decay: float):
-    grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
-    p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
-    for p_, z_, g_ in zip(p32, z32, g32):
+    for op, oz, g_ in zip(p, z, grad):
+        g_ = g_.view_as(op)
+        p_, z_, g_ = map(promote, (op, oz, g_))
         if decay != 0:
-            g_.add_(p_, alpha=decay)
-        p_.lerp_(z_, ckp1)
-        p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
-        z_.add_(g_, alpha=lr)
-    copy_stochastic_list_(p, p32)
-    copy_stochastic_list_(z, z32)
+            g_ = g_ + p_ * decay
+        p_ = p_.lerp(z_, ckp1)
+        p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
+        z_ = z_ + g_ * -lr
+        copy_stochastic_(op, p_)
+        copy_stochastic_(oz, z_)


 def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
@@ -164,9 +157,9 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
                             out: List[Optional[Tensor]]):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta2)
-    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+    s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
-    [d.clamp_(min=eps) for d in denom]
+    denom = [d.clamp(min=eps) for d in denom]
     copy_stochastic_list_(state, s32)

     if out[0] is None:
@@ -184,13 +177,9 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
-    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    s32 = torch._foreach_mul(s32, beta2)
-    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
-    denom = torch._foreach_sqrt(s32)
-    [d.clamp_(min=eps) for d in denom]
+    g32 = promote(grad)
+    denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
     out = torch._foreach_div(g32, denom)
-    copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, out)

@@ -201,10 +190,10 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
     return grad


+# TODO: This lerp was fucked - check other lerps
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
-    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+    s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)

@@ -218,14 +207,16 @@ def scale_by_exp_avg_(state, grad, beta):

 @decorator_knowngood
 def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
-    p_norm = torch._foreach_norm(parameters)
-    g_norm = torch._foreach_norm(gradients)
-    torch._foreach_maximum_(p_norm, minimum)
-    torch._foreach_maximum_(g_norm, eps)
-    torch._foreach_div_(p_norm, g_norm)
-    torch._foreach_mul_(p_norm, clip_val)
-    torch._foreach_minimum_(p_norm, 1)
-    torch._foreach_mul_(gradients, p_norm)
+    p32, g32 = [list(map(promote, x)) for x in (parameters, gradients)]
+    p_norm = torch._foreach_norm(p32)
+    g_norm = torch._foreach_norm(g32)
+    p_norm = torch._foreach_maximum(p_norm, minimum)
+    g_norm = torch._foreach_maximum(g_norm, eps)
+    p_norm = torch._foreach_div(p_norm, g_norm)
+    p_norm = torch._foreach_mul(p_norm, clip_val)
+    p_norm = torch._foreach_minimum(p_norm, 1)
+    g32 = torch._foreach_mul(g32, p_norm)
+    copy_stochastic_list_(gradients, g32)


 def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
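
The rewritten `_compilable_agc_` computes the same adaptive-gradient-clipping factor as before, now out-of-place in fp32 and written back with stochastic rounding. As a rough standalone illustration of the per-tensor rule it implements, min(1, clip_val * max(||p||, minimum) / max(||g||, eps)), in plain PyTorch (the default values below are illustrative, not heavyball's):

import torch

def agc_factor(p: torch.Tensor, g: torch.Tensor, clip_val: float, minimum: float = 1e-3, eps: float = 1e-8):
    # Gradient is scaled by min(1, clip_val * max(||p||, minimum) / max(||g||, eps)).
    p_norm = p.norm().clamp(min=minimum)
    g_norm = g.norm().clamp(min=eps)
    return (clip_val * p_norm / g_norm).clamp(max=1)

p = torch.randn(64, 64)
g = torch.randn(64, 64) * 10
g_clipped = g * agc_factor(p, g, clip_val=1e-2)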
@@ -246,10 +237,6 @@ def is_compiling():


 def set_(dst: Tensor, src: Tensor):
-    if not is_compiling() and src.data_ptr() == dst.data_ptr():
-        return
-    if src.shape != dst.shape:
-        src = src.reshape_as(dst)
     dst.copy_(src)

@@ -306,7 +293,7 @@ def ortho(x):
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta)
-    torch._foreach_add_(s32, g32)
+    s32 = torch._foreach_add(s32, g32)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)

@@ -315,8 +302,8 @@
 def _compilable_nesterov_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     s32 = torch._foreach_mul(s32, beta)
-    torch._foreach_add_(s32, g32)
-    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    s32 = torch._foreach_add(s32, g32)
+    g32 = [g + s * beta for g, s in zip(g32, s32)]
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, g32)

@@ -353,7 +340,7 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
     elif scale_mode == "scale":
         y *= max(1, x.size(0) / x.size(1)) ** 0.5
     elif scale_mode == "graft":
-        y *= x.norm() / y.norm().clamp_(min=1e-6)
+        y *= x.norm() / y.norm().clamp(min=1e-6)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
     set_(out, y)
@@ -509,8 +496,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-        x32.add_(y32, alpha=alpha)  # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
-        copy_stochastic_(x_, x32)
+        copy_stochastic_(x_, x32 + y32 * alpha)


 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
@@ -634,10 +620,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
     def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
                                beta1: float = -1.0):
         for p in group["params"]:
-            if skip_none and p.grad is None:
-                continue
-
             if p.grad is None:
+                if skip_none:
+                    continue
                 grad = None
             else:
                 if should_promote:
@@ -792,7 +777,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
     exp_avg32 = _lerp32(exp_avg, u32, beta1)
     denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
-    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+    _compilable_update_(y, u32, decay, lr, caution, g32)


 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -837,7 +822,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(u32, denom)
     u32 = _lerp32(exp_avg, u32, beta1)
-    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)
+    _compilable_update_(y, u32, decay, lr, caution, gp32)


 def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -850,22 +835,19 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
 @decorator_knowngood
 def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
-    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+    _compilable_update_(y, u32, decay, lr, caution, g32)

     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
-    [denom.clamp_(min=eps) for denom in denom]
-    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+    denom = [d.clamp(min=eps) for d in denom]
+    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
     copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
+    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


-
 def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
@@ -879,14 +861,12 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):

     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
-    [denom.clamp_(min=1e-8) for denom in denom]
-    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+    denom = [d.clamp(min=1e-8) for d in denom]
+    exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
     copy_stochastic_list_(exp_avg, exp_avg32)

     beta2 = beta_debias(beta2, step + 1)
-    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+    exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

     copy_stochastic_list_(grad, update)
@@ -921,39 +901,31 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):


 def copy_stochastic_(target: Tensor, source: Tensor):
-    if not is_compiling() and target.data_ptr() == source.data_ptr():
-        return
     if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
         _compilable_copy_stochastic_(target, source.float())
     set_(target, source)


 @decorator_knowngood
-def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
                         g: List[Optional[Tensor]]):
     u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
     p32, u32 = [list(map(promote, x)) for x in [p, u]]

-    if decay > 0:
-        torch._foreach_mul_(p32, 1 - decay * lr)
-
-    for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
+    for p32_, u32_, g_, p_ in zip(p32, u32, g, p):  # lr is data-dependent -> can't compile a foreach
         if caution:
             u32_ = _compilable_cautioning(promote(g_), u32_)
-        add_fn(p32_, u32_, lr)
-
-    copy_stochastic_list_(p, p32)
+        p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
+        copy_stochastic_(p_, p32_)


-def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
-                  caution: bool = False, grad: List[Tensor] = None):
+def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
+                  grad: List[Tensor] = None):
     param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
-    if add_fn is None:
-        add_fn = stochastic_add_
-    _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
+    _compilable_update_(param, update, decay, lr, caution, grad)


 def precond_schedule(step, precond_scheduler, rng):
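
`_compilable_update_` now folds decoupled weight decay and the (negated) learning rate into one per-tensor expression, p ← p·(1 − decay·lr) − lr·u, and writes each result back with stochastic rounding instead of mutating the fp32 copies in place. A minimal numeric illustration of that update rule in plain PyTorch (outside heavyball, values chosen arbitrarily):

import torch

p = torch.tensor([1.0, -2.0])   # parameter
u = torch.tensor([0.5, 0.25])   # preconditioned update direction
lr, decay = 0.1, 0.01

new_p = p * (1 - decay * lr) + u * -lr   # same form as the rewritten loop body
print(new_p)                             # approximately [0.9490, -2.0230]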
@@ -1194,6 +1166,7 @@ def identity(x):

 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+    # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
     g32 = list(map(promote, grad))
     [g.mul_(1 / scale) for g in g32]
     tanh = torch._foreach_tanh(g32)
@@ -1247,6 +1220,12 @@ def update_triu_(q_state, materialised):
         assert shape0 == shape1
         copy_stochastic_(q, m)

+_warned = set()
+
+def warn_once(msg):
+    if msg not in _warned:
+        warnings.warn(msg)
+        _warned.add(msg)

 def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
                        name: str = 'cumulative_prob'):
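
The new `warn_once` helper deduplicates warnings by message text. A quick usage sketch based only on the hunk above, assuming `heavyball.utils` is importable as shipped in this release:

import warnings
import heavyball.utils as hb_utils

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hb_utils.warn_once("tensor step not supported")   # emitted
    hb_utils.warn_once("tensor step not supported")   # suppressed, message already seen

print(len(caught))  # 1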
@@ -1291,6 +1270,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
     return new.to(ea.dtype)


+@decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
     precond = psgd_precond_grad(expr, grad, *preconds)
     update_param_(param, precond, lr, decay, caution=caution, grad=grad)
@@ -1371,3 +1351,22 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

     for p in model.parameters():
         p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
+
+
+def fused_hook(parameters, optimizer, *args, **kwargs):
+    parameters = list(parameters)
+    param_count = len(parameters)
+    seen_params = set()
+
+    o = optimizer(parameters, *args, **kwargs)
+
+    def _step(p: Tensor):
+        seen_params.add(p)
+
+        if len(seen_params) < param_count:
+            o.step()
+            o.zero_grad()
+            seen_params.clear()
+
+    for p in parameters:
+        p.register_post_accumulate_grad_hook(_step)
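
The new `fused_hook` builds a single optimizer over all given parameters and steps it from a post-accumulate-grad hook, so the training loop needs no explicit `step()` / `zero_grad()` calls. A hedged wiring sketch; the toy model and `heavyball.ForeachAdamW` are illustrative assumptions, not mandated by this diff:

import torch
import heavyball
from heavyball.utils import fused_hook   # added in this version

model = torch.nn.Linear(16, 4)
fused_hook(model.parameters(), heavyball.ForeachAdamW, lr=1e-3)   # hypothetical optimizer choice

for _ in range(10):
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()   # hooks fire after grads accumulate; the optimizer steps inside the hook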
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.1.3
+Version: 1.2.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
+heavyball/utils.py,sha256=n80bsTZ2NUD4L3YERaC7ydKaQzW4kSDNoBTLF0DfE1g,47393
+heavyball-1.2.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.2.1.dist-info/METADATA,sha256=eBqL5rxjms6zSclp1IjQKI8vmv1OOgvrmdOroZnE1GQ,12022
+heavyball-1.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.2.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.2.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
2
- heavyball/chainable.py,sha256=IrTlhiHvBlc1SaUUQDb1ulVfj0nrPJcqoP52AXWO3cI,20362
3
- heavyball/utils.py,sha256=0j5wRDYeI9Elz9m8tcP7CZNhj_9OIWEF_uQpb0LTrYM,47814
4
- heavyball-1.1.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.1.3.dist-info/METADATA,sha256=ILjX-OviZL9vO7bgw7ldGx-8XXIa1YfWRYCW94EtAOI,12022
6
- heavyball-1.1.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.1.3.dist-info/RECORD,,