heavyball 1.1.3__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -1,5 +1,6 @@
  import functools
  import random
+ import warnings
  from typing import Optional, Union, Literal

  import torch
@@ -251,7 +252,11 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
  step = group['step']
  if 'precondition_frequency' in group:
  return step > 0 and step % group['precondition_frequency'] == 0
- rng = random.Random(0x172381 ^ step)
+ if isinstance(step, torch.Tensor):
+ utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
+ rng = random.Random(0x172381)
+ else:
+ rng = random.Random(0x172381 ^ step)
  if 'precond_scheduler' in group:
  return utils.precond_schedule(step, group['precond_scheduler'], rng)
  if prob is not None:
@@ -414,6 +419,8 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):


  class ChainOpt(utils.StatefulOptimizer):
+ compile_step: bool = False
+
  def __init__(self, params, defaults, foreach: bool, *fns):
  super().__init__(params, defaults, foreach)
  self.fns = tuple(fns)
@@ -421,23 +428,36 @@ class ChainOpt(utils.StatefulOptimizer):
  def _step(self, group):
  if 'base_lr' not in group:
  group['base_lr'] = group['lr']
- step = group['step'] = group.get('step', 0) + 1
- if group['warmup_steps'] and step < group['warmup_steps']:
- group['lr'] = -group['base_lr'] * step / group['warmup_steps']
- else:
- group['lr'] = -group['base_lr']

  vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
  if not vals:
  return
  p, g = zip(*vals)

+ for param in p:
+ state = self.state_(param)
+ if 'step' not in state:
+ if self.compile_step:
+ step = utils.scalar_guard(0, param)
+ state['step'] = step
+ step = state['step'].add_(1)
+ break
+
+ group['step'] = step
+
+ if group['warmup_steps'] and step < group['warmup_steps']:
+ group['lr'] = group['base_lr'] * step / group['warmup_steps']
+ else:
+ group['lr'] = group['base_lr']
+
  if not group['foreach'] or len(p) == 1:
  for param, grad in zip(p, g):
  chain(self.state_, group, [grad], [param], *self.fns)
- return
+ else:
+ chain(self.state_, group, g, p, *self.fns)

- chain(self.state_, group, g, p, *self.fns)
+ group['lr'] = None
+ group['step'] = None


  use_default = object()
@@ -454,15 +474,15 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):


  def default(a, b):
- return b if a is None or a is use_default else a
+ return b if a is use_default else a


  # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
- _scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd, #
- scale_by_psgd: update_by_psgd, #
- scale_by_adam: update_by_adam, #
- scale_by_laprop: update_by_laprop, #
- scale_by_adopt: update_by_adopt}
+ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+ scale_by_psgd.get_fn(): update_by_psgd, #
+ scale_by_adam.get_fn(): update_by_adam, #
+ scale_by_laprop.get_fn(): update_by_laprop, #
+ scale_by_adopt.get_fn(): update_by_adopt}


  class BaseOpt(ChainOpt):
@@ -470,16 +490,17 @@ class BaseOpt(ChainOpt):
  update_clipping: str_or_fn = None
  palm: bool = False
  auto_fuse: bool = True
- compile_step: bool = False

  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = use_default, *fns):
+ palm: bool = use_default, compile_step: bool = use_default, *fns):
  if default(update_clipping, self.update_clipping) is None:
  if fns and self.auto_fuse:
  args, kwargs = None, None
  fn = fns[-1]
  if isinstance(fn, functools.partial):
- fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
+ fn, args, kwargs = fn.func, fn.args, fn.keywords
+ if isinstance(fn, FunctionTransform):
+ fn = fn.get_fn()
  if fn in _scale_to_update_map:
  fn = _scale_to_update_map[fn]
  if args is not None:
@@ -491,6 +512,7 @@ class BaseOpt(ChainOpt):

  fns = tuple(fns)

+ self.compile_step = default(compile_step, self.compile_step)
  if default(palm, self.palm):
  fns = (palm_beta2,) + fns
  if default(gradient_clipping, self.gradient_clipping) is not None:
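The `_step` rewrite above moves the step counter into per-parameter state and drops the negated learning rate, so warmup now scales a positive `base_lr`. A minimal sketch of the resulting warmup schedule; the function name and signature are illustrative only, not part of the package:

```python
def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Mirrors the expression added in ChainOpt._step above:
    # linear ramp from base_lr / warmup_steps up to base_lr, then constant.
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr
```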
heavyball/utils.py CHANGED
@@ -1,11 +1,3 @@
- """
-
-
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
  import functools
  import gc
  import math
@@ -70,16 +62,16 @@ def warmup(lr: float, step: int, warmup_steps: int):
  @decorator_knowngood
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
  beta1: Tensor, decay: float):
- grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
- p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
- for p_, z_, g_ in zip(p32, z32, g32):
+ for op, oz, g_ in zip(p, z, grad):
+ g_ = g_.view_as(op)
+ p_, z_, g_ = map(promote, (op, oz, g_))
  if decay != 0:
- g_.add_(p_, alpha=decay)
- p_.lerp_(z_, ckp1)
- p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
- z_.add_(g_, alpha=lr)
- copy_stochastic_list_(p, p32)
- copy_stochastic_list_(z, z32)
+ g_ = g_ + p_ * decay
+ p_ = p_.lerp(z_, ckp1)
+ p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
+ z_ = z_ + g_ * -lr
+ copy_stochastic_(op, p_)
+ copy_stochastic_(oz, z_)


  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
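The per-tensor schedule-free update above now runs out of place in fp32: with decay it adds `decay * p` to the gradient, interpolates `p` toward `z` by `ckp1`, steps `p` by `-lr * (1 - beta1 * (1 - ckp1)) * g`, and steps `z` by `-lr * g` before the stochastic copy-back. A minimal single-tensor sketch with hypothetical argument names:

```python
import torch

def schedule_free_step(p: torch.Tensor, z: torch.Tensor, g: torch.Tensor,
                       lr: float, beta1: float, ckp1: float, decay: float = 0.0):
    # Out-of-place fp32 arithmetic mirroring _compilable_schedule_free_ above.
    if decay != 0:
        g = g + p * decay
    p = p.lerp(z, ckp1)                           # pull p toward z
    p = p + g * (lr * (beta1 * (1 - ckp1)) - lr)  # i.e. p - lr * (1 - beta1*(1-ckp1)) * g
    z = z - lr * g
    return p, z
```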
@@ -164,9 +156,9 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
  out: List[Optional[Tensor]]):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta2)
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+ s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
- [d.clamp_(min=eps) for d in denom]
+ denom = [d.clamp(min=eps) for d in denom]
  copy_stochastic_list_(state, s32)

  if out[0] is None:
@@ -184,13 +176,9 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
- s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- s32 = torch._foreach_mul(s32, beta2)
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
- denom = torch._foreach_sqrt(s32)
- [d.clamp_(min=eps) for d in denom]
+ g32 = promote(grad)
+ denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
  out = torch._foreach_div(g32, denom)
- copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, out)


@@ -201,10 +189,10 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
  return grad


+ # TODO: This lerp was fucked - check other lerps
  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
- s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+ s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)

@@ -218,14 +206,16 @@ def scale_by_exp_avg_(state, grad, beta):

  @decorator_knowngood
  def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
- p_norm = torch._foreach_norm(parameters)
- g_norm = torch._foreach_norm(gradients)
- torch._foreach_maximum_(p_norm, minimum)
- torch._foreach_maximum_(g_norm, eps)
- torch._foreach_div_(p_norm, g_norm)
- torch._foreach_mul_(p_norm, clip_val)
- torch._foreach_minimum_(p_norm, 1)
- torch._foreach_mul_(gradients, p_norm)
+ p32, g32 = [list(map(promote, x)) for x in (parameters, gradients)]
+ p_norm = torch._foreach_norm(p32)
+ g_norm = torch._foreach_norm(g32)
+ p_norm = torch._foreach_maximum(p_norm, minimum)
+ g_norm = torch._foreach_maximum(g_norm, eps)
+ p_norm = torch._foreach_div(p_norm, g_norm)
+ p_norm = torch._foreach_mul(p_norm, clip_val)
+ p_norm = torch._foreach_minimum(p_norm, 1)
+ g32 = torch._foreach_mul(g32, p_norm)
+ copy_stochastic_list_(gradients, g32)


  def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
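For reference, the rewritten `_compilable_agc_` above computes a per-tensor clip factor `min(1, clip_val * max(‖p‖, minimum) / max(‖g‖, eps))` in fp32 and writes the scaled gradients back stochastically. A minimal single-tensor sketch; the default values here are illustrative:

```python
import torch

def agc_clip(p: torch.Tensor, g: torch.Tensor, clip_val: float,
             minimum: float = 1e-3, eps: float = 1e-8) -> torch.Tensor:
    # Single-tensor version of the foreach pipeline above.
    p_norm = p.norm().clamp(min=minimum)
    g_norm = g.norm().clamp(min=eps)
    scale = (clip_val * p_norm / g_norm).clamp(max=1)
    return g * scale
```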
@@ -246,10 +236,6 @@ def is_compiling():


  def set_(dst: Tensor, src: Tensor):
- if not is_compiling() and src.data_ptr() == dst.data_ptr():
- return
- if src.shape != dst.shape:
- src = src.reshape_as(dst)
  dst.copy_(src)

@@ -306,7 +292,7 @@ def ortho(x):
  def _compilable_heavyball_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta)
- torch._foreach_add_(s32, g32)
+ s32 = torch._foreach_add(s32, g32)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)
@@ -315,8 +301,8 @@ def _compilable_heavyball_momentum_(state, grad, beta):
  def _compilable_nesterov_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta)
- torch._foreach_add_(s32, g32)
- [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+ s32 = torch._foreach_add(s32, g32)
+ g32 = [g + s * beta for g, s in zip(g32, s32)]
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, g32)
@@ -353,7 +339,7 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  elif scale_mode == "scale":
  y *= max(1, x.size(0) / x.size(1)) ** 0.5
  elif scale_mode == "graft":
- y *= x.norm() / y.norm().clamp_(min=1e-6)
+ y *= x.norm() / y.norm().clamp(min=1e-6)
  else:
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
  set_(out, y)
@@ -509,8 +495,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.add_(y32, alpha=alpha) # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
- copy_stochastic_(x_, x32)
+ copy_stochastic_(x_, x32 + y32 * alpha)


  def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
@@ -634,10 +619,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
  beta1: float = -1.0):
  for p in group["params"]:
- if skip_none and p.grad is None:
- continue
-
  if p.grad is None:
+ if skip_none:
+ continue
  grad = None
  else:
  if should_promote:
@@ -792,7 +776,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
  exp_avg32 = _lerp32(exp_avg, u32, beta1)
  denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+ _compilable_update_(y, u32, decay, lr, caution, g32)


  def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -837,7 +821,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
  denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(u32, denom)
  u32 = _lerp32(exp_avg, u32, beta1)
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)
+ _compilable_update_(y, u32, decay, lr, caution, gp32)


  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -850,22 +834,19 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
  @decorator_knowngood
  def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+ _compilable_update_(y, u32, decay, lr, caution, g32)

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
- [denom.clamp_(min=eps) for denom in denom]
- exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+ denom = [d.clamp(min=eps) for d in denom]
+ exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
  copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


-
  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
@@ -879,14 +860,12 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
- [denom.clamp_(min=1e-8) for denom in denom]
- exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ denom = [d.clamp(min=1e-8) for d in denom]
+ exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
  copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

  copy_stochastic_list_(grad, update)
@@ -921,39 +900,31 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):


  def copy_stochastic_(target: Tensor, source: Tensor):
- if not is_compiling() and target.data_ptr() == source.data_ptr():
- return
  if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
  _compilable_copy_stochastic_(target, source.float())
  set_(target, source)


  @decorator_knowngood
- def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
  g: List[Optional[Tensor]]):
  u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
  p32, u32 = [list(map(promote, x)) for x in [p, u]]

- if decay > 0:
- torch._foreach_mul_(p32, 1 - decay * lr)
-
- for p32_, u32_, g_ in zip(p32, u32, g): # lr is data-dependent -> can't compile a foreach
+ for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
  if caution:
  u32_ = _compilable_cautioning(promote(g_), u32_)
- add_fn(p32_, u32_, lr)
-
- copy_stochastic_list_(p, p32)
+ p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
+ copy_stochastic_(p_, p32_)


- def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
- caution: bool = False, grad: List[Tensor] = None):
+ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
+ grad: List[Tensor] = None):
  param, update, grad = list_guard(param, update, grad)
  lr = scalar_guard(lr, param[0])
  if not caution:
  grad = [None] * len(param)
- if add_fn is None:
- add_fn = stochastic_add_
- _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
+ _compilable_update_(param, update, decay, lr, caution, grad)


  def precond_schedule(step, precond_scheduler, rng):
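The update path above no longer takes an `add_fn`: decoupled weight decay and the learning-rate step are folded into one out-of-place expression per parameter, followed by a stochastic copy back into the (possibly bf16) parameter. A minimal sketch of that arithmetic, assuming fp32 tensors and illustrative names:

```python
import torch

def apply_update(p: torch.Tensor, u: torch.Tensor, lr: float, decay: float) -> torch.Tensor:
    # Decoupled weight decay plus the step, matching
    # p32_ * (1 - decay * lr) + u32_ * -lr in _compilable_update_ above.
    return p * (1 - decay * lr) - lr * u
```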
@@ -1194,6 +1165,7 @@ def identity(x):

  @decorator_knowngood
  def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+ # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
  g32 = list(map(promote, grad))
  [g.mul_(1 / scale) for g in g32]
  tanh = torch._foreach_tanh(g32)
@@ -1247,6 +1219,12 @@ def update_triu_(q_state, materialised):
  assert shape0 == shape1
  copy_stochastic_(q, m)

+ _warned = set()
+
+ def warn_once(msg):
+ if msg not in _warned:
+ warnings.warn(msg)
+ _warned.add(msg)

  def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
  name: str = 'cumulative_prob'):
@@ -1291,6 +1269,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
  return new.to(ea.dtype)


+ @decorator_knowngood
  def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
  precond = psgd_precond_grad(expr, grad, *preconds)
  update_param_(param, precond, lr, decay, caution=caution, grad=grad)
@@ -1371,3 +1350,22 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):

  for p in model.parameters():
  p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
+
+
+ def fused_hook(parameters, optimizer, *args, **kwargs):
+ parameters = list(parameters)
+ param_count = len(parameters)
+ seen_params = set()
+
+ o = optimizer(parameters, *args, **kwargs)
+
+ def _step(p: Tensor):
+ seen_params.add(p)
+
+ if len(seen_params) < param_count:
+ o.step()
+ o.zero_grad()
+ seen_params.clear()
+
+ for p in parameters:
+ p.register_post_accumulate_grad_hook(_step)
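The new `fused_hook` above builds a single optimizer over all given parameters and registers a post-accumulate-grad hook on each one, driving `step()` and `zero_grad()` from those hooks rather than from an explicit call in the training loop. A hedged usage sketch; the model and the `ForeachAdamW` choice are placeholders, not part of this diff:

```python
import torch
from torch import nn

import heavyball
from heavyball.utils import fused_hook

model = nn.Linear(16, 1)  # placeholder model
# Any heavyball optimizer class should fit the optimizer(parameters, *args, **kwargs)
# call made inside fused_hook; ForeachAdamW is used here only as an example.
fused_hook(model.parameters(), heavyball.ForeachAdamW, lr=1e-3)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # the registered hooks handle step() and zero_grad()
```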
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.1.3
+ Version: 1.2.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+ heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
+ heavyball/utils.py,sha256=H0r2GpqRS1c6qIYqW5rFYA-020AVVVWbfGne17mzlcM,47377
+ heavyball-1.2.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.2.0.dist-info/METADATA,sha256=YzMGNrvU_RIKGn13r8GO8kp05s9Me5PWyD3KvEd09Uo,12022
+ heavyball-1.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.2.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.2.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
- heavyball/chainable.py,sha256=IrTlhiHvBlc1SaUUQDb1ulVfj0nrPJcqoP52AXWO3cI,20362
- heavyball/utils.py,sha256=0j5wRDYeI9Elz9m8tcP7CZNhj_9OIWEF_uQpb0LTrYM,47814
- heavyball-1.1.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.1.3.dist-info/METADATA,sha256=ILjX-OviZL9vO7bgw7ldGx-8XXIa1YfWRYCW94EtAOI,12022
- heavyball-1.1.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.1.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.1.3.dist-info/RECORD,,