heavyball 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -1,6 +1,7 @@
  import functools
  import random
- from typing import Optional, Union
+ import warnings
+ from typing import Optional, Union, Literal
 
  import torch
 
@@ -51,8 +52,7 @@ class FunctionTransform:
 
 
  def _zero_guard(state, key, ref, dtype):
- return _guard_in_state(state, key,
- lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
+ return _guard_in_state(state, key, lambda: torch.zeros_like(ref, dtype=dtype, memory_format=torch.preserve_format))
 
 
  def _storage_dtype(group):
@@ -252,7 +252,11 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
  step = group['step']
  if 'precondition_frequency' in group:
  return step > 0 and step % group['precondition_frequency'] == 0
- rng = random.Random(0x172381 ^ step)
+ if isinstance(step, torch.Tensor):
+ utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
+ rng = random.Random(0x172381)
+ else:
+ rng = random.Random(0x172381 ^ step)
  if 'precond_scheduler' in group:
  return utils.precond_schedule(step, group['precond_scheduler'], rng)
  if prob is not None:
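For context, a standalone sketch of the new branch (plain Python and PyTorch, not heavyball internals): with an integer step the RNG stream stays step-dependent, while a tensor step, which appears to arise when `compile_step` keeps `state['step']` as a tensor, falls back to a fixed seed and warns once. `pick_rng` is an illustrative name, not part of the library.

```python
import random
import warnings

import torch


def pick_rng(step):
    # pick_rng is a made-up helper that mirrors the branch added above.
    if isinstance(step, torch.Tensor):
        warnings.warn("Preconditioner schedule is not supported with torch.Tensor step.")
        return random.Random(0x172381)  # fixed-seed fallback
    return random.Random(0x172381 ^ step)  # step-dependent stream


print(pick_rng(12).random())                # varies with the step value
print(pick_rng(torch.tensor(12)).random())  # always the same stream, after one warning
```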
@@ -415,6 +419,8 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 
 
  class ChainOpt(utils.StatefulOptimizer):
+ compile_step: bool = False
+
  def __init__(self, params, defaults, foreach: bool, *fns):
  super().__init__(params, defaults, foreach)
  self.fns = tuple(fns)
@@ -422,27 +428,40 @@ class ChainOpt(utils.StatefulOptimizer):
  def _step(self, group):
  if 'base_lr' not in group:
  group['base_lr'] = group['lr']
- step = group['step'] = group.get('step', 0) + 1
- if group['warmup_steps'] and step < group['warmup_steps']:
- group['lr'] = -group['base_lr'] * step / group['warmup_steps']
- else:
- group['lr'] = -group['base_lr']
 
  vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
  if not vals:
  return
  p, g = zip(*vals)
 
+ for param in p:
+ state = self.state_(param)
+ if 'step' not in state:
+ if self.compile_step:
+ step = utils.scalar_guard(0, param)
+ state['step'] = step
+ step = state['step'].add_(1)
+ break
+
+ group['step'] = step
+
+ if group['warmup_steps'] and step < group['warmup_steps']:
+ group['lr'] = group['base_lr'] * step / group['warmup_steps']
+ else:
+ group['lr'] = group['base_lr']
+
  if not group['foreach'] or len(p) == 1:
  for param, grad in zip(p, g):
  chain(self.state_, group, [grad], [param], *self.fns)
- return
+ else:
+ chain(self.state_, group, g, p, *self.fns)
 
- chain(self.state_, group, g, p, *self.fns)
+ group['lr'] = None
+ group['step'] = None
 
 
  use_default = object()
- str_or_fn = Union[str, callable, None, use_default]
+ str_or_fn = Union[str, callable, None, Literal[use_default]]
 
 
  def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
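Worth noting for this hunk: the step counter now lives in per-parameter state, the sign flip that used to be folded into `group['lr']` moved into the parameter update, and `lr`/`step` are cleared once the step completes. A minimal sketch of the warm-up ramp as written above (assumed semantics; `effective_lr` is an illustrative name, not library API):

```python
def effective_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # Linear ramp from 0 to base_lr over the first warmup_steps steps.
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr


assert effective_lr(1e-3, 5, 100) == 5e-5    # early step: scaled down
assert effective_lr(1e-3, 100, 100) == 1e-3  # past warm-up: full base_lr
```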
@@ -455,15 +474,15 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
 
 
  def default(a, b):
- return b if a is None or a is use_default else a
+ return b if a is use_default else a
 
 
  # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq
- _scale_to_update_map = {scale_by_delayed_psgd: update_by_delayed_psgd, #
- scale_by_psgd: update_by_psgd, #
- scale_by_adam: update_by_adam, #
- scale_by_laprop: update_by_laprop, #
- scale_by_adopt: update_by_adopt}
+ _scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, #
+ scale_by_psgd.get_fn(): update_by_psgd, #
+ scale_by_adam.get_fn(): update_by_adam, #
+ scale_by_laprop.get_fn(): update_by_laprop, #
+ scale_by_adopt.get_fn(): update_by_adopt}
 
 
  class BaseOpt(ChainOpt):
@@ -471,16 +490,17 @@ class BaseOpt(ChainOpt):
  update_clipping: str_or_fn = None
  palm: bool = False
  auto_fuse: bool = True
- compile_step: bool = False
 
  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = use_default, *fns):
+ palm: bool = use_default, compile_step: bool = use_default, *fns):
  if default(update_clipping, self.update_clipping) is None:
  if fns and self.auto_fuse:
  args, kwargs = None, None
  fn = fns[-1]
  if isinstance(fn, functools.partial):
- fn, args, kwargs = fns[-1].func, fns[-1].args, fns[-1].keywords
+ fn, args, kwargs = fn.func, fn.args, fn.keywords
+ if isinstance(fn, FunctionTransform):
+ fn = fn.get_fn()
  if fn in _scale_to_update_map:
  fn = _scale_to_update_map[fn]
  if args is not None:
@@ -492,6 +512,7 @@ class BaseOpt(ChainOpt):
 
  fns = tuple(fns)
 
+ self.compile_step = default(compile_step, self.compile_step)
  if default(palm, self.palm):
  fns = (palm_beta2,) + fns
  if default(gradient_clipping, self.gradient_clipping) is not None:
heavyball/utils.py CHANGED
@@ -1,11 +1,3 @@
- """
-
-
- Originally from Evan Walters and Omead Pooladzandi, 2024
- Modified under Creative Commons Attribution 4.0 International
- Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- """
-
  import functools
  import gc
  import math
@@ -70,16 +62,16 @@ def warmup(lr: float, step: int, warmup_steps: int):
  @decorator_knowngood
  def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, grad: List[Tensor], lr: Tensor,
  beta1: Tensor, decay: float):
- grad = [u_.view_as(p_) for u_, p_ in zip(grad, p)]
- p32, z32, g32 = [list(map(promote, x)) for x in (p, z, grad)]
- for p_, z_, g_ in zip(p32, z32, g32):
+ for op, oz, g_ in zip(p, z, grad):
+ g_ = g_.view_as(op)
+ p_, z_, g_ = map(promote, (op, oz, g_))
  if decay != 0:
- g_.add_(p_, alpha=decay)
- p_.lerp_(z_, ckp1)
- p_.add_(g_, alpha=lr - lr * (beta1 * (1 - ckp1)))
- z_.add_(g_, alpha=lr)
- copy_stochastic_list_(p, p32)
- copy_stochastic_list_(z, z32)
+ g_ = g_ + p_ * decay
+ p_ = p_.lerp(z_, ckp1)
+ p_ = p_ + g_ * (lr * (beta1 * (1 - ckp1)) - lr)
+ z_ = z_ + g_ * -lr
+ copy_stochastic_(op, p_)
+ copy_stochastic_(oz, z_)
 
 
  def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor],
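Reading the rewritten loop as update equations (my transcription of the code above, with $c$ = `ckp1`, $\eta$ = `lr`, $\lambda$ = `decay`), the arithmetic is unchanged; the in-place ops just became out-of-place with a per-tensor stochastic write-back:

```latex
\begin{aligned}
g &\leftarrow g + \lambda\, p \\
p &\leftarrow (1 - c)\, p + c\, z \\
p &\leftarrow p - \eta \bigl(1 - \beta_1 (1 - c)\bigr)\, g \\
z &\leftarrow z - \eta\, g
\end{aligned}
```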
@@ -164,9 +156,9 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
  out: List[Optional[Tensor]]):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta2)
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
+ s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
- [d.clamp_(min=eps) for d in denom]
+ denom = [d.clamp(min=eps) for d in denom]
  copy_stochastic_list_(state, s32)
 
  if out[0] is None:
@@ -184,13 +176,9 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
 
  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
- s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- s32 = torch._foreach_mul(s32, beta2)
- [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
- denom = torch._foreach_sqrt(s32)
- [d.clamp_(min=eps) for d in denom]
+ g32 = promote(grad)
+ denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
  out = torch._foreach_div(g32, denom)
- copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, out)
 
 
@@ -201,10 +189,10 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
  return grad
 
 
+ # TODO: This lerp was fucked - check other lerps
  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
- s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+ s32 = [s.lerp(g, 1 - beta) for s, g in zip(promote(state), promote(grad))]
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)
 
@@ -218,14 +206,16 @@ def scale_by_exp_avg_(state, grad, beta):
 
  @decorator_knowngood
  def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
- p_norm = torch._foreach_norm(parameters)
- g_norm = torch._foreach_norm(gradients)
- torch._foreach_maximum_(p_norm, minimum)
- torch._foreach_maximum_(g_norm, eps)
- torch._foreach_div_(p_norm, g_norm)
- torch._foreach_mul_(p_norm, clip_val)
- torch._foreach_minimum_(p_norm, 1)
- torch._foreach_mul_(gradients, p_norm)
+ p32, g32 = [list(map(promote, x)) for x in (parameters, gradients)]
+ p_norm = torch._foreach_norm(p32)
+ g_norm = torch._foreach_norm(g32)
+ p_norm = torch._foreach_maximum(p_norm, minimum)
+ g_norm = torch._foreach_maximum(g_norm, eps)
+ p_norm = torch._foreach_div(p_norm, g_norm)
+ p_norm = torch._foreach_mul(p_norm, clip_val)
+ p_norm = torch._foreach_minimum(p_norm, 1)
+ g32 = torch._foreach_mul(g32, p_norm)
+ copy_stochastic_list_(gradients, g32)
 
 
  def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
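The rewritten `_compilable_agc_` computes the same adaptive-gradient-clipping factor as before, now out-of-place on promoted fp32 copies with a stochastic write-back. Per tensor, as I read the code:

```latex
g \leftarrow g \cdot \min\left(1,\ \mathrm{clip\_val} \cdot \frac{\max(\lVert p \rVert,\ \mathrm{minimum})}{\max(\lVert g \rVert,\ \mathrm{eps})}\right)
```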
@@ -246,10 +236,6 @@ def is_compiling():
 
 
  def set_(dst: Tensor, src: Tensor):
- if not is_compiling() and src.data_ptr() == dst.data_ptr():
- return
- if src.shape != dst.shape:
- src = src.reshape_as(dst)
  dst.copy_(src)
 
 
@@ -306,7 +292,7 @@ def ortho(x):
  def _compilable_heavyball_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta)
- torch._foreach_add_(s32, g32)
+ s32 = torch._foreach_add(s32, g32)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)
 
@@ -315,8 +301,8 @@ def _compilable_heavyball_momentum_(state, grad, beta):
  def _compilable_nesterov_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
  s32 = torch._foreach_mul(s32, beta)
- torch._foreach_add_(s32, g32)
- [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+ s32 = torch._foreach_add(s32, g32)
+ g32 = [g + s * beta for g, s in zip(g32, s32)]
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, g32)
 
@@ -353,7 +339,7 @@ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  elif scale_mode == "scale":
  y *= max(1, x.size(0) / x.size(1)) ** 0.5
  elif scale_mode == "graft":
- y *= x.norm() / y.norm().clamp_(min=1e-6)
+ y *= x.norm() / y.norm().clamp(min=1e-6)
  else:
  raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
  set_(out, y)
@@ -509,8 +495,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.add_(y32, alpha=alpha) # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
- copy_stochastic_(x_, x32)
+ copy_stochastic_(x_, x32 + y32 * alpha)
 
 
  def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
@@ -634,10 +619,9 @@ class StatefulOptimizer(torch.optim.Optimizer):
  def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True,
  beta1: float = -1.0):
  for p in group["params"]:
- if skip_none and p.grad is None:
- continue
-
  if p.grad is None:
+ if skip_none:
+ continue
  grad = None
  else:
  if should_promote:
@@ -792,7 +776,7 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
  exp_avg32 = _lerp32(exp_avg, u32, beta1)
  denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+ _compilable_update_(y, u32, decay, lr, caution, g32)
 
 
  def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -837,7 +821,7 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
  denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(u32, denom)
  u32 = _lerp32(exp_avg, u32, beta1)
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)
+ _compilable_update_(y, u32, decay, lr, caution, gp32)
 
 
  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
@@ -850,22 +834,19 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso
  @decorator_knowngood
  def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
- _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
+ _compilable_update_(y, u32, decay, lr, caution, g32)
 
  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
- [denom.clamp_(min=eps) for denom in denom]
- exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+ denom = [d.clamp(min=eps) for d in denom]
+ exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
  copy_stochastic_list_(exp_avg, exp_avg32)
 
  beta2 = beta_debias(beta2, step + 1)
- exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-
  def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
@@ -879,14 +860,12 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
 
  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
- [denom.clamp_(min=1e-8) for denom in denom]
- exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ denom = [d.clamp(min=1e-8) for d in denom]
+ exp_avg32 = [ea32.lerp(g / d, 1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
  copy_stochastic_list_(exp_avg, exp_avg32)
 
  beta2 = beta_debias(beta2, step + 1)
- exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
  copy_stochastic_list_(grad, update)
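Both ADOPT paths now phrase the moment updates as lerps, which is algebraically the same as the old mul-then-addcdiv/addcmul form. In math, with debiased $\hat\beta_1$, $\hat\beta_2$ from `beta_debias` (my transcription of the code, not a formula stated in the package):

```latex
\begin{aligned}
m &\leftarrow \hat\beta_1\, m + (1 - \hat\beta_1)\, \frac{g}{\max(\sqrt{v},\ \varepsilon)} \\
v &\leftarrow \hat\beta_2\, v + (1 - \hat\beta_2)\, g^{2}
\end{aligned}
```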
@@ -921,39 +900,31 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):
 
 
  def copy_stochastic_(target: Tensor, source: Tensor):
- if not is_compiling() and target.data_ptr() == source.data_ptr():
- return
  if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
  _compilable_copy_stochastic_(target, source.float())
  set_(target, source)
 
 
  @decorator_knowngood
- def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn: callable, lr: Tensor, caution: bool,
+ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool,
  g: List[Optional[Tensor]]):
  u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
  p32, u32 = [list(map(promote, x)) for x in [p, u]]
 
- if decay > 0:
- torch._foreach_mul_(p32, 1 - decay * lr)
-
- for p32_, u32_, g_ in zip(p32, u32, g): # lr is data-dependent -> can't compile a foreach
+ for p32_, u32_, g_, p_ in zip(p32, u32, g, p): # lr is data-dependent -> can't compile a foreach
  if caution:
  u32_ = _compilable_cautioning(promote(g_), u32_)
- add_fn(p32_, u32_, lr)
-
- copy_stochastic_list_(p, p32)
+ p32_ = p32_ * (1 - decay * lr) + u32_ * -lr
+ copy_stochastic_(p_, p32_)
 
 
- def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
- caution: bool = False, grad: List[Tensor] = None):
+ def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False,
+ grad: List[Tensor] = None):
  param, update, grad = list_guard(param, update, grad)
  lr = scalar_guard(lr, param[0])
  if not caution:
  grad = [None] * len(param)
- if add_fn is None:
- add_fn = stochastic_add_
- _compilable_update_(param, update, decay, add_fn, lr, caution, grad)
+ _compilable_update_(param, update, decay, lr, caution, grad)
 
 
  def precond_schedule(step, precond_scheduler, rng):
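With `add_fn` removed, decoupled weight decay and the step are fused into one out-of-place expression per tensor, written back stochastically; cautioning, when enabled, first masks the update against the gradient sign. As a formula (my reading of the loop body, with $\eta$ = `lr`, $\lambda$ = `decay`, $u$ the possibly-cautioned update):

```latex
p \leftarrow (1 - \eta \lambda)\, p - \eta\, u
```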
@@ -1194,6 +1165,7 @@ def identity(x):
 
  @decorator_knowngood
  def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
+ # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
  g32 = list(map(promote, grad))
  [g.mul_(1 / scale) for g in g32]
  tanh = torch._foreach_tanh(g32)
@@ -1247,6 +1219,12 @@ def update_triu_(q_state, materialised):
  assert shape0 == shape1
  copy_stochastic_(q, m)
 
+ _warned = set()
+
+ def warn_once(msg):
+ if msg not in _warned:
+ warnings.warn(msg)
+ _warned.add(msg)
 
  def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None,
  name: str = 'cumulative_prob'):
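A minimal check of the deduplication `warn_once` provides (a sketch in a fresh interpreter, not a heavyball test):

```python
import warnings

from heavyball.utils import warn_once  # added in this release

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_once("tensor step not supported")
    warn_once("tensor step not supported")  # same message: swallowed by the _warned set

print(len(caught))  # 1
```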
@@ -1291,6 +1269,7 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
  return new.to(ea.dtype)
 
 
+ @decorator_knowngood
  def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
  precond = psgd_precond_grad(expr, grad, *preconds)
  update_param_(param, precond, lr, decay, caution=caution, grad=grad)
@@ -1371,3 +1350,22 @@ def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
 
  for p in model.parameters():
  p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
+
+
+ def fused_hook(parameters, optimizer, *args, **kwargs):
+ parameters = list(parameters)
+ param_count = len(parameters)
+ seen_params = set()
+
+ o = optimizer(parameters, *args, **kwargs)
+
+ def _step(p: Tensor):
+ seen_params.add(p)
+
+ if len(seen_params) < param_count:
+ o.step()
+ o.zero_grad()
+ seen_params.clear()
+
+ for p in parameters:
+ p.register_post_accumulate_grad_hook(_step)
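A hedged usage sketch of the new `fused_hook`: unlike `hook_optimizer_into_model`, it builds a single optimizer over all parameters and drives `step()`/`zero_grad()` from the post-accumulate-grad hooks, so the training loop only calls `backward()`. `heavyball.ForeachAdamW` is an assumed optimizer class here, not something this diff shows; any heavyball optimizer constructor should fit the same call shape.

```python
import torch

import heavyball
from heavyball.utils import fused_hook

model = torch.nn.Linear(16, 4)
# One optimizer over all parameters; hooks fire as each gradient is accumulated.
fused_hook(model.parameters(), heavyball.ForeachAdamW, lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(3):
    loss = (model(x) - y).square().mean()
    loss.backward()  # step() and zero_grad() run from the gradient hooks, no explicit optimizer call
```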
heavyball-1.1.2.dist-info/METADATA → heavyball-1.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.1.2
+ Version: 1.2.0
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
heavyball-1.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+ heavyball/chainable.py,sha256=5CrtSVaTI9DIgPPy0DD3WbWyVmc6-3jd2E5zM2frQlI,21092
+ heavyball/utils.py,sha256=H0r2GpqRS1c6qIYqW5rFYA-020AVVVWbfGne17mzlcM,47377
+ heavyball-1.2.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.2.0.dist-info/METADATA,sha256=YzMGNrvU_RIKGn13r8GO8kp05s9Me5PWyD3KvEd09Uo,12022
+ heavyball-1.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.2.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.2.0.dist-info/RECORD,,
heavyball-1.1.2.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
- heavyball/chainable.py,sha256=Zp7q6RHYU4RgdZ_ezgc8NWPwsNfyFjRvhEK-IEqr4b4,20379
- heavyball/utils.py,sha256=0j5wRDYeI9Elz9m8tcP7CZNhj_9OIWEF_uQpb0LTrYM,47814
- heavyball-1.1.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.1.2.dist-info/METADATA,sha256=bhXVJpcuwNZaOKFydknhtqqYx0ZZsQp2wkEdUAoDfN4,12022
- heavyball-1.1.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.1.2.dist-info/RECORD,,