heavyball 2.0.0.dev0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +168 -29
- heavyball/chainable.py +165 -63
- heavyball/helpers.py +5 -1
- heavyball/utils.py +507 -124
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/METADATA +19 -7
- heavyball-2.1.1.dist-info/RECORD +9 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/WHEEL +1 -1
- heavyball-2.0.0.dev0.dist-info/RECORD +0 -9
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {heavyball-2.0.0.dev0.dist-info → heavyball-2.1.1.dist-info}/top_level.txt +0 -0
heavyball/chainable.py
CHANGED
@@ -2,6 +2,7 @@ import copy
 import functools
 import math
 import random
+from collections.abc import Iterable as _Iterable
 from typing import Iterable, List, Literal, Optional, Union

 import torch
@@ -85,6 +86,22 @@ class FunctionTransform:
         return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"


+class Branch:
+    def __init__(self, branches: List[List[callable]], merge_fn: callable):
+        self.branches = branches
+        self.merge_fn = merge_fn
+
+    def __call__(self, state, group, update, grad, param):
+        outputs = []
+        for branch in self.branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return self.merge_fn(outputs)
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key, lambda: torch.zeros_like(ref, dtype=dtype, memory_format=torch.preserve_format))

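Note: the new Branch transform clones the incoming update once per sub-chain, runs each sub-chain through _inner_chain, and passes the list of per-branch outputs to merge_fn. A minimal sketch of a compatible merge function, assuming each branch yields one tensor per parameter; the averaging helper and the chosen branches are illustrative, not part of heavyball:

import torch

def mean_merge(branch_outputs):
    # branch_outputs: one list of tensors per branch, aligned across parameters.
    return [torch.stack(per_param).mean(0) for per_param in zip(*branch_outputs)]

# Hypothetical wiring: average the outputs of two sub-chains.
# branch = Branch([[exp_avg], [identity]], mean_merge)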
@@ -117,7 +134,7 @@ class PrecondGradAccumGuard(FunctionTransform):
             utils.stochastic_add_(state, new)

     def _reset(self, state):
-        if self.steps_taken
+        if self.steps_taken != 0:
            self.steps_taken = 0
            utils.zero_(state)

@@ -214,6 +231,23 @@ class NoStateNoForeach(FunctionTransform):
         return updates


+class SqueezeGrad(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        original_shapes = [u.shape for u in update]
+        update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
+        grad = [x.view_as(u) for x, u in zip(grad, update)]
+        param = [x.view_as(u) for x, u in zip(param, update)]
+        args = list(args)
+        for i, a in enumerate(args):
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                args[i] = [x.view_as(u) for x, u in zip(a, update)]
+        for k, a in kwargs.items():
+            if isinstance(a, (list, tuple)) and isinstance(a[0], Tensor):
+                kwargs[k] = [x.view_as(u) for x, u in zip(a, update)]
+        out = self.fn(state, group, update, grad, param, *args, **kwargs)
+        return [o.view(s) for o, s in zip(out, original_shapes)]
+
+
 def zero_guard(*names):
     return functools.partial(ZeroGuard, names=names)

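Note: SqueezeGrad strips singleton dimensions from every update (or views scalars as 1-D), reshapes grad, param, and any tensor-list arguments to match, calls the wrapped transform, and then restores the original shapes. The same round-trip on bare tensors, as a standalone illustration:

import torch

updates = [torch.randn(4, 1, 3), torch.randn(1)]
shapes = [u.shape for u in updates]
squeezed = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in updates]  # shapes (4, 3) and (1,)
restored = [s.view(shape) for s, shape in zip(squeezed, shapes)]
assert all(r.shape == orig for r, orig in zip(restored, shapes))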
@@ -247,11 +281,11 @@ def exp_avg(group, update, grad, param, exp_avg):
 @copy_guard(2, "init")
 @no_state
 def weight_decay_to_init(group, update, grad, param, init):
-    utils.
-
-
-
-
+    utils.stochastic_lerp_(param, init, group["weight_decay_to_ema"] * group["lr"])
+    return update
+
+
+def identity(state, group, update, grad, param):
     return update

@@ -321,6 +355,26 @@ def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None


+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def update_by_adamc(group, update, grad, param, exp_avg, exp_avg_sq):
+    utils.fused_adam_(
+        param,
+        exp_avg,
+        exp_avg_sq,
+        update,
+        grad,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["lr"],
+        group["eps"],
+        group["lr"] * group["weight_decay"] / group["max_lr"],
+        group["caution"],
+    )
+    raise SkipUpdate from None
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
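Note: the distinguishing piece visible in update_by_adamc is the decay term it hands to utils.fused_adam_: lr * weight_decay / max_lr, so the applied weight decay scales with the current learning rate relative to its peak instead of staying constant. A toy calculation of that term only (illustrative, not heavyball code):

max_lr, weight_decay = 1e-3, 0.1
for lr in (1e-3, 5e-4, 1e-4):  # e.g. points on a decaying schedule
    print(f"lr={lr:.0e} -> decay term {lr * weight_decay / max_lr:.1e}")
# The term passed to the fused kernel shrinks in proportion to lr / max_lr.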
@@ -426,6 +480,43 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None


+@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
+@no_state_no_foreach
+def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+    if group["step"] == 1:
+        utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+        raise SkipUpdate from None
+
+    precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
+    precond_update = utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        precond_update.view_as(exp_avg),
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+    )[0]
+    precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)
+
+    new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    utils.copy_stochastic_(fisher_approx, new_approx)
+    return precond_update
+
+
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+    update = utils.unscaled_adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+    )
+    return update
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
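Note: scale_by_suds keeps a rank-1 fisher_approx vector that is refreshed every step with utils.oja_update. For orientation only, a generic Oja iteration for tracking the leading eigenvector of E[g gᵀ] looks roughly like this textbook sketch; it is not heavyball's implementation:

import torch

def oja_step(w: torch.Tensor, g: torch.Tensor, lr: float) -> torch.Tensor:
    # Pull w toward the top eigenvector of E[g g^T], then renormalize.
    w = w + lr * torch.dot(g, w) * g
    return w / w.norm().clamp(min=1e-8)

w = torch.nn.functional.normalize(torch.randn(16), dim=0)
for _ in range(100):
    w = oja_step(w, torch.randn(16), lr=0.01)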
@@ -473,7 +564,8 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
         dtype=getattr(torch, group["q_dtype"]),
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
-    state["running_lower_bound"] = [torch.zeros((), device=q.device, dtype=
+    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.int64)
     if group["adaptive"]:
         state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
     if not cached:
@@ -623,7 +715,7 @@ def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner:


 def _update_psgd_precond(
-    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, prob: Optional[callable] = None
+    cached, Q_cache, group, param, grad, Q, velocity, running_lower_bound, step, prob: Optional[callable] = None
 ) -> Optional[Tensor]:
     if prob is None:
         prob = utils.precond_update_prob_schedule()
@@ -640,13 +732,13 @@ def _update_psgd_precond(
     else:
         vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])

-    precond =
+    precond = utils.psgd_update_precond(
         hessian_vector,
         group["precond_lr"],
         Q,
         group["store_triu_as_line"],
         velocity,
-        utils.
+        utils.get_beta2(group),
         group["ortho_method"],
         vector,
         running_lower_bound,
@@ -723,6 +815,7 @@ def _update_lra(
     )


+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -731,6 +824,7 @@ def scale_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))


+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -740,6 +834,7 @@ def update_by_psgd_lra(group, update, grad, param, update_to_precond, U, V, d):
     raise SkipUpdate from None


+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -748,6 +843,7 @@ def scale_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
     return utils.extract_from_flat_update(param, utils.lra_precond(u, v, d, utils.flatten(update)))


+@SqueezeGrad
 @PrecondGradAccumGuard
 @general_guard("U", "V", "d", init_fn=_init_psgd_lra, skip_first=False)
 @no_state
@@ -757,8 +853,9 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
     raise SkipUpdate from None


+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_psgd(
     group,
@@ -770,15 +867,17 @@ def scale_by_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)


+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(
     group,
@@ -790,6 +889,7 @@ def scale_by_delayed_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
@@ -797,12 +897,15 @@ def scale_by_delayed_psgd(
         precond = None
     else:
         precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
-    new = _update_psgd_precond(
+    new = _update_psgd_precond(
+        cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
+    )
     return new if precond is None else precond


+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_psgd(
     group,
@@ -814,10 +917,11 @@ def update_by_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
     raise SkipUpdate from None

@@ -833,8 +937,9 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
     return clip_fn(update)


+@SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(
     group,
@@ -846,11 +951,12 @@ def update_by_delayed_psgd(
     Q_cache,
     velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
+    step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
     raise SkipUpdate from None

@@ -888,55 +994,50 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
         utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)


-def
-
-
-
-
-
-
-
-
-
+def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
+    if retain and offset:
+        raise ValueError("offset cannot be retained")
+
+    def _walk(obj):
+        stack = [obj]
+        while stack:
+            cur = stack.pop()
+            if isinstance(cur, FunctionTransform):
+                yield cur
+                stack.append(cur.fn)
+            elif isinstance(cur, functools.partial):
+                stack.append(cur.func)
+            elif isinstance(cur, Branch):
+                for branch in cur.branches:
+                    stack.extend(branch)
+            elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
+                stack.extend(cur)

-
+    if retain:
+        offset = max((ft.transform_idx for ft in _walk(fns) if ft.transform_idx is not None), default=-1) + 1

+    new_fns = [copy.deepcopy(fn) for fn in fns]
+    for ft in _walk(new_fns):
+        if not retain or ft.transform_idx is None:
+            ft.transform_idx, offset = offset, offset + 1

-
-    if retain:
-        if offset:
-            raise ValueError("offset cannot be retained")
-
-        offset = -1
-        for fn in fns:
-            while isinstance(fn, (FunctionTransform, functools.partial)):
-                if isinstance(fn, functools.partial):
-                    fn = fn.func
-                    continue
-                if fn.transform_idx is not None:
-                    offset = max(offset, fn.transform_idx)
-                fn = fn.fn
-        offset += 1  # if we found nothing, this will be 0. if we found something, we START at N+1
-
-    fns = [copy.deepcopy(fn) for fn in fns]
-    for fn in fns:
-        while isinstance(fn, (FunctionTransform, functools.partial)):
-            if isinstance(fn, functools.partial):
-                fn = fn.func
-                continue
-            if not retain or fn.transform_idx is None:
-                fn.transform_idx = offset
-                offset += 1
-            fn = fn.fn
-    return fns
+    return new_fns


 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
+    global_defaults = {
+        "caution": False,
+        "lr": 1,
+        "warmup_steps": 0,
+        "weight_decay": 0,
+        "eps": 1e-8,
+    }

     def __init__(self, params, defaults, foreach: bool, *fns):
-
-
+        base = self.global_defaults.copy()
+        base.update({k: v for k, v in defaults.items() if v is not use_default})
+        super().__init__(params, base, foreach)
         self.fns = fns

     @property
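Note: the rewritten set_indices deep-copies the chain and assigns a fresh transform_idx to every FunctionTransform it can reach, now including transforms hidden inside functools.partial wrappers, Branch sub-chains, and nested iterables, which the old while-loop could not see into. A rough usage sketch, using transforms shown earlier in this diff:

# Assign state indices once, keeping any indices that already exist (retain=True).
chain_fns = set_indices([exp_avg, update_by_adam], retain=True)
for fn in chain_fns:
    print(getattr(fn, "transform_idx", None))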
@@ -1055,7 +1156,7 @@ class BaseOpt(ChainOpt):

     update_clipping: str_or_fn = None
         The function to use for clipping the outgoing updates before applying them, after all other transformations.
-        This will turn off
+        This will turn off fused updates.
         This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.

     """
@@ -1069,11 +1170,11 @@ class BaseOpt(ChainOpt):
         self,
         params,
         defaults,
-        foreach: bool,
-        gradient_clipping: str_or_fn,
-        update_clipping: str_or_fn,
+        foreach: bool = True,
+        gradient_clipping: str_or_fn = None,
+        update_clipping: str_or_fn = None,
         palm: bool = use_default,
-
+        fns: Iterable[callable] = (),
         compile_step: bool = use_default,
         promote: bool = use_default,
     ):
@@ -1081,6 +1182,7 @@ class BaseOpt(ChainOpt):
             raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`")

         args, kwargs = None, None
+        fns = tuple(fns)
         fn = fns[-1]
         if isinstance(fn, functools.partial):
             fn, args, kwargs = fn.func, fn.args, fn.keywords
heavyball/helpers.py
CHANGED
@@ -420,7 +420,7 @@ class FastINGO:
         population_size: Optional[int] = None,
         learning_rate: Optional[float] = None,
         last_n: int = 4096,
-        loco_step_size: float = 1,
+        loco_step_size: float = 0.1,
         device="cuda",
         batchnorm_decay: float = 0.99,
         score_decay: float = 0.99,
@@ -697,6 +697,10 @@ def init_nsgaii(study, seed, trials, search_space):
     return module.NSGAIIwITSampler(seed=seed)


+def init_random(study, seed, trials, search_space):
+    return optuna.samplers.RandomSampler(seed=seed)
+
+
 def init_ingo(study, seed, trials, search_space):
     return ImplicitNaturalGradientSampler(search_space=search_space, seed=seed)

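Note: init_random follows the same factory signature as the other init_* helpers and simply wraps Optuna's built-in random sampler. A usage sketch with a placeholder objective; the arguments other than seed are ignored by this factory:

import optuna

sampler = init_random(study=None, seed=0, trials=None, search_space=None)
study = optuna.create_study(sampler=sampler)
study.optimize(lambda trial: trial.suggest_float("x", -1.0, 1.0) ** 2, n_trials=10)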