heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -2,6 +2,7 @@ import copy
 import functools
 import math
 import random
+from collections.abc import Iterable as _Iterable
 from typing import Iterable, List, Literal, Optional, Union
 
 import torch
@@ -85,6 +86,22 @@ class FunctionTransform:
         return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
 
 
+class Branch:
+    def __init__(self, branches: List[List[callable]], merge_fn: callable):
+        self.branches = branches
+        self.merge_fn = merge_fn
+
+    def __call__(self, state, group, update, grad, param):
+        outputs = []
+        for branch in self.branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return self.merge_fn(outputs)
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key, lambda: torch.zeros_like(ref, dtype=dtype, memory_format=torch.preserve_format))
 
@@ -117,7 +134,7 @@ class PrecondGradAccumGuard(FunctionTransform):
             utils.stochastic_add_(state, new)
 
     def _reset(self, state):
-        if self.steps_taken == 0:
+        if self.steps_taken != 0:
             self.steps_taken = 0
             utils.zero_(state)
 
@@ -214,6 +231,23 @@ class NoStateNoForeach(FunctionTransform):
         return updates
 
 
+class SqueezeGrad(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        original_shapes = [u.shape for u in update]
+        update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
+        grad = [x.view_as(u) for x, u in zip(grad, update)]
+        param = [x.view_as(u) for x, u in zip(param, update)]
+        args = list(args)
+        for i, a in enumerate(args):
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                args[i] = [x.view_as(u) for x, u in zip(a, update)]
+        for k, a in kwargs.items():
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                kwargs[k] = [x.view_as(u) for x, u in zip(a, update)]
+        out = self.fn(state, group, update, grad, param, *args, **kwargs)
+        return [o.view(s) for o, s in zip(out, original_shapes)]
+
+
 def zero_guard(*names):
     return functools.partial(ZeroGuard, names=names)
 
@@ -247,11 +281,11 @@ def exp_avg(group, update, grad, param, exp_avg):
 @copy_guard(2, "init")
 @no_state
 def weight_decay_to_init(group, update, grad, param, init):
-    utils.weight_decay_to_init_(
-        param,
-        init,
-        group["weight_decay_to_ema"] * group["lr"],
-    )
+    utils.stochastic_lerp_(param, init, group["weight_decay_to_ema"] * group["lr"])
+    return update
+
+
+def identity(state, group, update, grad, param):
     return update
 
 
@@ -321,6 +355,26 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None
 
 
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def update_by_adamc(group, update, grad, param, exp_avg, exp_avg_sq):
+    utils.fused_adam_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["eps"],
+        group["lr"] * group["weight_decay"] / group["max_lr"],
+        group["caution"],
+    )
+    raise SkipUpdate from None
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
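
The new update_by_adamc above hands utils.fused_adam_ a decay term of group["lr"] * group["weight_decay"] / group["max_lr"], so the applied weight decay scales with the ratio of the current learning rate to max_lr. A small standalone illustration of that scaling; the concrete values are made up and max_lr is assumed to be the peak learning rate of the schedule:

    # Standalone illustration (not package code): how the decay term passed by
    # update_by_adamc behaves as the learning rate is annealed.
    max_lr, weight_decay = 1e-3, 0.1

    for lr in (1e-3, 5e-4, 1e-4):
        effective_decay = lr * weight_decay / max_lr
        print(f"lr={lr:.0e} -> effective decay {effective_decay:.3g}")
    # lr=1e-03 -> effective decay 0.1
    # lr=5e-04 -> effective decay 0.05
    # lr=1e-04 -> effective decay 0.01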
@@ -426,6 +480,43 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None
 
 
+@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
+@no_state_no_foreach
+def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+    if group["step"] == 1:
+        utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+        raise SkipUpdate from None
+
+    precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
+    precond_update = utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        precond_update.view_as(exp_avg),
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+    )[0]
+    precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)
+
+    new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    utils.copy_stochastic_(fisher_approx, new_approx)
+    return precond_update
+
+
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+    update = utils.unscaled_adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+    )
+    return update
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
@@ -473,7 +564,8 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
         dtype=getattr(torch, group["q_dtype"]),
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-    state["running_lower_bound"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
     if group["adaptive"]:
         state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
@@ -623,7 +715,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:
 
 
 def _update_psgd_precond(
-    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, prob: Optional[callable] = None
+    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, step, prob: Optional[callable] = None
 ) -> Optional[Tensor]:
     if prob is None:
         prob = utils.precond_update_prob_schedule()
@@ -640,13 +732,13 @@
     else:
         vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
 
-    precond = (utils.inverse_free_psgd_update_precond if vector is None else utils.psgd_update_precond)(
+    precond = utils.psgd_update_precond(
         hessian_vector,
         group["precond_lr"],
         Q,
         group["store_triu_as_line"],
         velocity,
-        utils.beta_debias(utils.get_beta2(group), group["step"]),
+        utils.get_beta2(group),
         group["ortho_method"],
         vector,
         running_lower_bound,
@@ -723,6 +815,7 @@ def _update_lra(
     )
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -731,6 +824,7 @@ def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -740,6 +834,7 @@ def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
     raise SkipUpdate from None
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -748,6 +843,7 @@ def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -757,8 +853,9 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
     raise SkipUpdate from None
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(
     group,
@@ -770,15 +867,17 @@ def scale_by_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(
     group,
@@ -790,6 +889,7 @@ def scale_by_delayed_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
@@ -797,12 +897,15 @@
         precond = None
     else:
         precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
-    new = _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    new = _update_psgd_precond(
+        cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
+    )
    return new if precond is None else precond
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(
     group,
@@ -814,10 +917,11 @@ def update_by_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
     raise SkipUpdate from None
 
@@ -833,8 +937,9 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
     return clip_fn(update)
 
 
+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(
     group,
@@ -846,11 +951,12 @@ def update_by_delayed_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     raise SkipUpdate from None
 
 
@@ -888,55 +994,50 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
     utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
 
 
-def create_branch(branches: List[List[callable]], merge_fn: callable):
-    def _branch(state, group, update, grad, param):
-        outputs = []
-        for branch in branches:
-            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
-            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
-            if skip_update:
-                raise ValueError("Branches should not skip updates")
-            outputs.append(branch_update)
-        return merge_fn(outputs)
+def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+    if retain and offset:
+        raise ValueError("offset cannot be retained")
+
+    def _walk(obj):
+        stack = [obj]
+        while stack:
+            cur = stack.pop()
+            if isinstance(cur, FunctionTransform):
+                yield cur
+                stack.append(cur.fn)
+            elif isinstance(cur, functools.partial):
+                stack.append(cur.func)
+            elif isinstance(cur, Branch):
+                for branch in cur.branches:
+                    stack.extend(branch)
+            elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
+                stack.extend(cur)
 
-    return _branch
+    if retain:
+        offset = max((ft.transform_idx for ft in _walk(fns) if ft.transform_idx is not None), default=-1) + 1
 
+    new_fns = [copy.deepcopy(fn) for fn in fns]
+    for ft in _walk(new_fns):
+        if not retain or ft.transform_idx is None:
+            ft.transform_idx, offset = offset, offset + 1
 
-def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
-    if retain:
-        if offset:
-            raise ValueError("offset cannot be retained")
-
-        offset = -1
-        for fn in fns:
-            while isinstance(fn, (FunctionTransform, functools.partial)):
-                if isinstance(fn, functools.partial):
-                    fn = fn.func
-                    continue
-                if fn.transform_idx is not None:
-                    offset = max(offset, fn.transform_idx)
-                fn = fn.fn
-        offset += 1  # if we found nothing, this will be 0. if we found something, we START at N+1
-
-    fns = [copy.deepcopy(fn) for fn in fns]
-    for fn in fns:
-        while isinstance(fn, (FunctionTransform, functools.partial)):
-            if isinstance(fn, functools.partial):
-                fn = fn.func
-                continue
-            if not retain or fn.transform_idx is None:
-                fn.transform_idx = offset
-                offset += 1
-            fn = fn.fn
-    return fns
+    return new_fns
 
 
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
+    global_defaults = {
+        "caution": False,
+        "lr": 1,
+        "warmup_steps": 0,
+        "weight_decay": 0,
+        "eps": 1e-8,
+    }
 
     def __init__(self, params, defaults, foreach: bool, *fns):
-        defaults = {k: v for k, v in defaults.items() if v is not use_default}
-        super().__init__(params, defaults, foreach)
+        base = self.global_defaults.copy()
+        base.update({k: v for k, v in defaults.items() if v is not use_default})
+        super().__init__(params, base, foreach)
         self.fns = fns
 
     @property
@@ -1055,7 +1156,7 @@ class BaseOpt(ChainOpt):
 
     update_clipping: str_or_fn = None
         The function to use for clipping the outgoing updates before applying them, after all other transformations.
-        This will turn off
+        This will turn off fused updates.
         This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
 
     """
@@ -1069,11 +1170,11 @@ class BaseOpt(ChainOpt):
         self,
         params,
         defaults,
-        foreach: bool,
-        gradient_clipping: str_or_fn,
-        update_clipping: str_or_fn,
+        foreach: bool = True,
+        gradient_clipping: str_or_fn = None,
+        update_clipping: str_or_fn = None,
         palm: bool = use_default,
-        *fns,
+        fns: Iterable[callable] = (),
         compile_step: bool = use_default,
         promote: bool = use_default,
     ):
@@ -1081,6 +1182,7 @@
            raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")
 
         args, kwargs = None, None
+        fns = tuple(fns)
         fn = fns[-1]
         if isinstance(fn, functools.partial):
             fn, args, kwargs = fn.func, fn.args, fn.keywords
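
The Branch class added in this release replaces the removed create_branch factory: each branch receives a clone of the incoming update, is run through its own sub-chain via _inner_chain, and merge_fn receives the list of per-branch results. A rough usage sketch, assuming heavyball.chainable exposes Branch alongside the existing exp_avg and scale_by_laprop transforms; the mean_merge helper and the particular sub-chains are illustrative, not part of the package:

    from heavyball import chainable as C

    def mean_merge(outputs):
        # `outputs` is a list of per-branch update lists; average them element-wise.
        return [sum(tensors) / len(tensors) for tensors in zip(*outputs)]

    # Run two transform chains on clones of the same update and average the results.
    branch = C.Branch(branches=[[C.exp_avg], [C.scale_by_laprop]], merge_fn=mean_merge)

    # Such a Branch could then sit in an optimizer chain, e.g. (hypothetical wiring):
    # opt = C.BaseOpt(model.parameters(), {"lr": 1e-3}, fns=[branch, C.update_by_adam])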
heavyball/helpers.py CHANGED
@@ -420,7 +420,7 @@ class FastINGO:
         population_size: Optional[int] = None,
         learning_rate: Optional[float] = None,
         last_n: int = 4096,
-        loco_step_size: float = 1,
+        loco_step_size: float = 0.1,
         device="cuda",
         batchnorm_decay: float = 0.99,
         score_decay: float = 0.99,
@@ -697,6 +697,10 @@ def init_nsgaii(study, seed, trials, search_space):
     return module.NSGAIIwITSampler(seed=seed)
 
 
+def init_random(study, seed, trials, search_space):
+    return optuna.samplers.RandomSampler(seed=seed)
+
+
 def init_ingo(study, seed, trials, search_space):
     return ImplicitNaturalGradientSampler(search_space=search_space, seed=seed)
 
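
The new init_random helper follows the same factory signature as the other sampler initializers in helpers.py (study, seed, trials, search_space) and simply returns a seeded optuna RandomSampler. A minimal sketch of how such a factory could be used, assuming only that it ignores the arguments it does not need; the objective below is purely illustrative:

    import optuna

    from heavyball.helpers import init_random

    # The factory only uses `seed`, so the other arguments can be placeholders here.
    sampler = init_random(study=None, seed=0, trials=None, search_space=None)
    study = optuna.create_study(sampler=sampler)
    study.optimize(lambda trial: trial.suggest_float("x", -1.0, 1.0) ** 2, n_trials=10)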