torchzero 0.3.13__tar.gz → 0.3.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.13 → torchzero-0.3.14}/PKG-INFO +1 -1
- {torchzero-0.3.13 → torchzero-0.3.14}/pyproject.toml +1 -1
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_opts.py +0 -7
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/core/module.py +4 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero-0.3.14/torchzero/modules/experimental/spsa1.py +93 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/__init__.py +1 -1
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/rfdm.py +27 -110
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/scipy.py +15 -3
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/strong_wolfe.py +0 -2
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/restarts/restars.py +5 -4
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/second_order/newton_cg.py +86 -110
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero-0.3.14/torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero-0.3.14/torchzero/modules/zeroth_order/cd.py +122 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/optimizer.py +2 -2
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero.egg-info/PKG-INFO +1 -1
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero.egg-info/SOURCES.txt +1 -0
- torchzero-0.3.13/torchzero/modules/zeroth_order/__init__.py +0 -1
- torchzero-0.3.13/torchzero/modules/zeroth_order/cd.py +0 -359
- {torchzero-0.3.13 → torchzero-0.3.14}/setup.cfg +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_identical.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_module.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_tensorlist.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/tests/test_vars.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/core/reformulation.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/core/transform.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/adagrad.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/adahessian.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/adam.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/adan.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/aegd.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/esgd.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/lion.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/lmadagrad.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/mars.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/msam.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/muon.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/natural_gradient.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/orthograd.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/rmsprop.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/rprop.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/sam.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/shampoo.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/soap.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/adaptive/sophia_h.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/dct.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/fft.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/l_infinity.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/momentum.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/newtonnewton.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/experimental/structural_projections.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/higher_order/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/higher_order/higher_order_newton.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/least_squares/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/least_squares/gn.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/_polyinterp.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/adaptive.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/debug.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/escape.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/gradient_accumulation.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/homotopy.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/misc.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/multistep.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/regularization.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/split.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/misc/switch.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/higher_level.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/projections/cast.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/damping.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/lsr1.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/restarts/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/second_order/multipoint.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/second_order/newton.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/second_order/nystrom.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/smoothing/sampling.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/step_size/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/step_size/adaptive.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/step_size/lr.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/termination/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/termination/termination.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/dogleg.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/trust_region/trust_region.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/variance_reduction/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/variance_reduction/svrg.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/root.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/directsearch.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/fcmaes.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/mads.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/optuna.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/optim/wrappers/scipy.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/benchmark.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/linear_operator.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/solve.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/metrics.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.13 → torchzero-0.3.14}/torchzero.egg-info/top_level.txt +0 -0

{torchzero-0.3.13 → torchzero-0.3.14}/tests/test_opts.py
RENAMED

@@ -400,13 +400,6 @@ RandomizedFDM_4samples = Run(
     func='booth', steps=50, loss=1e-5, merge_invariant=True,
     sphere_steps=100, sphere_loss=400,
 )
-RandomizedFDM_4samples_lerp = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
-    needs_closure=True,
-    func='booth', steps=50, loss=1e-5, merge_invariant=True,
-    sphere_steps=100, sphere_loss=505,
-)
 RandomizedFDM_4samples_no_pre_generate = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/core/module.py
RENAMED

@@ -531,7 +531,11 @@ class Module(ABC):
     def reset(self):
         """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
         self.state.clear()
+
+        generator = self.global_state.get("generator", None)
         self.global_state.clear()
+        if generator is not None: self.global_state["generator"] = generator
+
         for c in self.children.values(): c.reset()

     def reset_for_online(self):
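The added lines keep the module's random generator alive across `reset()`, so seeded perturbation sampling continues its stream instead of restarting whenever a restart strategy clears state. A minimal sketch of the pattern, using a plain dict as a stand-in for `global_state` (illustrative only, not the torchzero API):

```py
import torch

# stand-in for Module.global_state; "momentum" represents any other cached state
global_state = {"generator": torch.Generator().manual_seed(0), "momentum": 0.9}

def reset(global_state: dict) -> None:
    # keep the RNG so seeded sampling does not restart its stream after a reset
    generator = global_state.get("generator", None)
    global_state.clear()
    if generator is not None:
        global_state["generator"] = generator

reset(global_state)
assert "momentum" not in global_state and "generator" in global_state
```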
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/conjugate_gradient/cg.py
RENAMED

@@ -50,7 +50,7 @@ class ConguateGradientBase(Transform, ABC):
     ```

     """
-    def __init__(self, defaults
+    def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['restart_interval'] = restart_interval
         defaults['clip_beta'] = clip_beta

@@ -140,8 +140,8 @@ class PolakRibiere(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, clip_beta=True, restart_interval: int | None =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, clip_beta=True, restart_interval: int | None | Literal['auto'] = 'auto', inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return polak_ribiere_beta(g, prev_g)

@@ -158,7 +158,7 @@ class FletcherReeves(ConguateGradientBase):
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
     def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def initialize(self, p, g):
         self.global_state['prev_gg'] = g.dot(g)

@@ -183,8 +183,8 @@ class HestenesStiefel(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hestenes_stiefel_beta(g, prev_d, prev_g)

@@ -202,8 +202,8 @@ class DaiYuan(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return dai_yuan_beta(g, prev_d, prev_g)

@@ -221,8 +221,8 @@ class LiuStorey(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return liu_storey_beta(g, prev_d, prev_g)

@@ -239,8 +239,8 @@ class ConjugateDescent(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return conjugate_descent_beta(g, prev_d, prev_g)

@@ -264,8 +264,8 @@ class HagerZhang(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hager_zhang_beta(g, prev_d, prev_g)

@@ -291,8 +291,8 @@ class DYHS(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] =
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return dyhs_beta(g, prev_d, prev_g)
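All conjugate-gradient variants now pass an explicit defaults dict to the base class and default `restart_interval` to `'auto'`. Since the docstrings say the step size must come from a line search, a typical stack looks like the sketch below (assuming `model` exists and these modules are exported under `tz.m`):

```py
import torchzero as tz

opt = tz.Modular(
    model.parameters(),                              # `model` is assumed to exist
    tz.m.PolakRibiere(restart_interval='auto'),      # conjugate-gradient direction
    tz.m.StrongWolfe(c2=0.1, a_init="first-order"),  # line search recommended in the Note
)
```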
torchzero-0.3.14/torchzero/modules/experimental/spsa1.py

@@ -0,0 +1,93 @@
+from collections.abc import Callable
+from typing import Any
+from functools import partial
+import torch
+
+from ...utils import TensorList, NumberList
+from ..grad_approximation.grad_approximator import GradApproximator, GradTarget
+
+class SPSA1(GradApproximator):
+    """One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated
+    gradient often won't be a descent direction, however the expectation is biased towards
+    the descent direction. Therefore this variant of SPSA is only recommended for a specific
+    class of problems where the objective function changes on each evaluation,
+    for example feedback control problems.
+
+    Args:
+        h (float, optional):
+            finite difference step size, recommended to set to same value as learning rate. Defaults to 1e-3.
+        n_samples (int, optional): number of random samples. Defaults to 1.
+        eps (float, optional): measurement noise estimate. Defaults to 1e-8.
+        seed (int | None | torch.Generator, optional): random seed. Defaults to None.
+        target (GradTarget, optional): what to set on closure. Defaults to "closure".
+
+    Reference:
+        [SPALL, JAMES C. "A One-measurement Form of Simultaneous Stochastic Approximation](https://www.jhuapl.edu/spsa/PDF-SPSA/automatica97_one_measSPSA.pdf)."
+    """
+
+    def __init__(
+        self,
+        h: float = 1e-3,
+        n_samples: int = 1,
+        eps: float = 1e-8, # measurement noise
+        pre_generate = False,
+        seed: int | None | torch.Generator = None,
+        target: GradTarget = "closure",
+    ):
+        defaults = dict(h=h, eps=eps, n_samples=n_samples, pre_generate=pre_generate, seed=seed)
+        super().__init__(defaults, target=target)
+
+
+    def pre_step(self, var):
+
+        if self.defaults['pre_generate']:
+
+            params = TensorList(var.params)
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+            n_samples = self.defaults['n_samples']
+            h = self.get_settings(var.params, 'h')
+
+            perturbations = [params.sample_like(distribution='rademacher', generator=generator) for _ in range(n_samples)]
+            torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss):
+        generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+        params = TensorList(params)
+        orig_params = params.clone() # store to avoid small changes due to float imprecision
+        loss_approx = None
+
+        h, eps = self.get_settings(params, "h", "eps", cls=NumberList)
+        n_samples = self.defaults['n_samples']
+
+        default = [None]*n_samples
+        # perturbations are pre-multiplied by h
+        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+
+        grad = None
+        for i in range(n_samples):
+            prt = perturbations[i]
+
+            if prt[0] is None:
+                prt = params.sample_like('rademacher', generator=generator).mul_(h)
+
+            else: prt = TensorList(prt)
+
+            params += prt
+            L = closure(False)
+            params.copy_(orig_params)
+
+            sample = prt * ((L + eps) / h)
+            if grad is None: grad = sample
+            else: grad += sample
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        return grad, loss, loss_approx
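A hypothetical usage sketch for the new one-measurement SPSA module, following the `tz.Modular` pattern used in the other docstrings; the exact export path (`tz.m.SPSA1`) is assumed rather than confirmed by this diff:

```py
import torchzero as tz

opt = tz.Modular(
    model.parameters(),   # `model` is assumed to exist
    tz.m.SPSA1(h=1e-3),   # docstring recommends h close to the learning rate
    tz.m.LR(1e-3),
)
```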
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/forward_gradient.py
RENAMED

@@ -23,8 +23,6 @@ class ForwardGradient(RandomizedFDM):
     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         jvp_method (str, optional):

@@ -40,14 +38,13 @@ class ForwardGradient(RandomizedFDM):
         self,
         n_samples: int = 1,
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
         h: float = 1e-3,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples, distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples, distribution=distribution, target=target, pre_generate=pre_generate, seed=seed)
         self.defaults['jvp_method'] = jvp_method

     @torch.no_grad

@@ -62,7 +59,7 @@ class ForwardGradient(RandomizedFDM):
         distribution = settings['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-        generator = self.
+        generator = self.get_generator(params[0].device, self.defaults['seed'])

         grad = None
         for i in range(n_samples):
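With the `beta` momentum argument removed, `ForwardGradient` is configured only through its sampling options and `jvp_method`. A usage sketch under the same assumptions as above (`model` exists, the module is exported under `tz.m`):

```py
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.ForwardGradient(n_samples=1, distribution="gaussian", jvp_method="autograd"),
    tz.m.LR(1e-3),
)
```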
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/grad_approximation/rfdm.py
RENAMED

@@ -164,7 +164,6 @@ class RandomizedFDM(GradApproximator):
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "rademacher".
             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
-        beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -173,7 +172,7 @@ class RandomizedFDM(GradApproximator):
     Examples:
         #### Simultaneous perturbation stochastic approximation (SPSA) method

-        SPSA is randomized
+        SPSA is randomized FDM with rademacher distribution and central formula.
         ```py
         spsa = tz.Modular(
             model.parameters(),

@@ -184,8 +183,7 @@ class RandomizedFDM(GradApproximator):

         #### Random-direction stochastic approximation (RDSA) method

-        RDSA is randomized
-
+        RDSA is randomized FDM with usually gaussian distribution and central formula.
         ```
         rdsa = tz.Modular(
             model.parameters(),

@@ -194,23 +192,9 @@ class RandomizedFDM(GradApproximator):
         )
         ```

-        #### RandomizedFDM with momentum
-
-        Momentum might help by reducing the variance of the estimated gradients.
-
-        ```
-        momentum_spsa = tz.Modular(
-            model.parameters(),
-            tz.m.RandomizedFDM(),
-            tz.m.HeavyBall(0.9),
-            tz.m.LR(1e-3)
-        )
-        ```
-
         #### Gaussian smoothing method

         GS uses many gaussian samples with possibly a larger finite difference step size.
-
         ```
         gs = tz.Modular(
             model.parameters(),

@@ -220,44 +204,15 @@ class RandomizedFDM(GradApproximator):
         )
         ```

-        ####
-
-        NewtonCG with hessian-vector product estimated via gradient difference
-        calls closure multiple times per step. If each closure call estimates gradients
-        with different perturbations, NewtonCG is unable to produce useful directions.
-
-        By setting pre_generate to True, perturbations are generated once before each step,
-        and each closure call estimates gradients using the same pre-generated perturbations.
-        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+        #### RandomizedFDM with momentum

+        Momentum might help by reducing the variance of the estimated gradients.
         ```
-
+        momentum_spsa = tz.Modular(
             model.parameters(),
-            tz.m.RandomizedFDM(
-            tz.m.
-            tz.m.
-        )
-        ```
-
-        #### SPSA-LBFGS
-
-        LBFGS uses a memory of past parameter and gradient differences. If past gradients
-        were estimated with different perturbations, LBFGS directions will be useless.
-
-        To alleviate this momentum can be added to random perturbations to make sure they only
-        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
-        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
-
-        Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
-
-        ```
-        opt = tz.Modular(
-            bench.parameters(),
-            tz.m.ResetEvery(
-                [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
-                steps = 100,
-            ),
-            tz.m.Backtracking()
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
         )
         ```
     """

@@ -268,75 +223,46 @@ class RandomizedFDM(GradApproximator):
         n_samples: int = 1,
         formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
-        beta: float = 0,
         pre_generate = True,
         seed: int | None | torch.Generator = None,
         target: GradTarget = "closure",
     ):
-        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution,
+        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
         super().__init__(defaults, target=target)

-    def reset(self):
-        self.state.clear()
-        generator = self.global_state.get('generator', None) # avoid resetting generator
-        self.global_state.clear()
-        if generator is not None: self.global_state['generator'] = generator
-        for c in self.children.values(): c.reset()
-
-    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
-        if 'generator' not in self.global_state:
-            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
-            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            else: self.global_state['generator'] = None
-        return self.global_state['generator']

     def pre_step(self, var):
-        h
-
-        n_samples = self.defaults['n_samples']
-        distribution = self.defaults['distribution']
+        h = self.get_settings(var.params, 'h')
         pre_generate = self.defaults['pre_generate']

         if pre_generate:
+            n_samples = self.defaults['n_samples']
+            distribution = self.defaults['distribution']
+
             params = TensorList(var.params)
-            generator = self.
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
             perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

+            # this is false for ForwardGradient where h isn't used and it subclasses this
             if self.PRE_MULTIPLY_BY_H:
                 torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

-
-
-            for param, prt in zip(params, zip(*perturbations)):
-                self.state[param]['perturbations'] = prt
-
-            else:
-                # lerp old and new perturbations. This makes the subspace change gradually
-                # which in theory might improve algorithms with history
-                for i,p in enumerate(params):
-                    state = self.state[p]
-                    if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
-
-                cur = [self.state[p]['perturbations'][:n_samples] for p in params]
-                cur_flat = [p for l in cur for p in l]
-                new_flat = [p for l in zip(*perturbations) for p in l]
-                betas = [1-v for b in beta for v in [b]*n_samples]
-                torch._foreach_lerp_(cur_flat, new_flat, betas)
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt

     @torch.no_grad
     def approximate(self, closure, params, loss):
         params = TensorList(params)
-        orig_params = params.clone() # store to avoid small changes due to float imprecision
         loss_approx = None

         h = NumberList(self.settings[p]['h'] for p in params)
-
-
-        fd_fn = _RFD_FUNCS[
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-
-        generator = self._get_generator(settings['seed'], params)
+        generator = self.get_generator(params[0].device, self.defaults['seed'])

         grad = None
         for i in range(n_samples):

@@ -356,7 +282,6 @@ class RandomizedFDM(GradApproximator):
             if grad is None: grad = prt * d
             else: grad += prt * d

-        params.set_(orig_params)
         assert grad is not None
         if n_samples > 1: grad.div_(n_samples)

@@ -384,8 +309,6 @@ class SPSA(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "rademacher".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -408,8 +331,6 @@ class RDSA(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -425,12 +346,11 @@ class RDSA(RandomizedFDM):
         n_samples: int = 1,
         formula: _FD_Formula = "central2",
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

 class GaussianSmoothing(RandomizedFDM):
     """

@@ -445,8 +365,6 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 100.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
         distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -462,12 +380,11 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples: int = 100,
         formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

@@ -525,9 +442,9 @@ class MeZO(GradApproximator):
         loss_approx = None

         h = NumberList(self.settings[p]['h'] for p in params)
-
-
-
+        n_samples = self.defaults['n_samples']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
         prt_fns = self.global_state['prt_fns']

         grad = None
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/scipy.py
RENAMED

@@ -1,3 +1,4 @@
+import math
 from collections.abc import Mapping
 from operator import itemgetter

@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
         bounds (Sequence | None, optional):
             For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
         tol (float | None, optional): Tolerance for termination. Defaults to None.
+        prev_init (bool, optional): uses previous step size as initial guess for the line search.
         options (dict | None, optional): A dictionary of solver options. Defaults to None.

         For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
         bracket=None,
         bounds=None,
         tol: float | None = None,
+        prev_init: bool = False,
         options=None,
     ):
-        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
         super().__init__(defaults)

         import scipy.optimize

@@ -48,5 +51,14 @@ class ScipyMinimizeScalar(LineSearchBase):
         options = dict(options) if isinstance(options, Mapping) else {}
         options['maxiter'] = maxiter

-
-
+        if self.defaults["prev_init"] and "x_prev" in self.global_state:
+            if bracket is None: bracket = (0, 1)
+            bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+        max = torch.finfo(var.params[0].dtype).max / 2
+        if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+        self.global_state['x_prev'] = x
+        return x
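The new `prev_init` flag seeds the scalar minimizer's bracket with the step size found on the previous iteration. A usage sketch (assuming `tz.m.ScipyMinimizeScalar` is the exported name and `model` exists):

```py
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.FletcherReeves(),
    tz.m.ScipyMinimizeScalar(prev_init=True),  # reuse the last step size as the initial guess
)
```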
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/line_search/strong_wolfe.py
RENAMED

@@ -330,7 +330,6 @@ class StrongWolfe(LineSearchBase):
         if adaptive:
             a_init *= self.global_state.get('initial_scale', 1)

-
         strong_wolfe = _StrongWolfe(
             f=objective,
             f_0=f_0,

@@ -360,7 +359,6 @@ class StrongWolfe(LineSearchBase):
         if inverted: a = -a

         if a is not None and a != 0 and math.isfinite(a):
-            #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
             self.global_state['initial_scale'] = 1
             self.global_state['a_prev'] = a
             self.global_state['f_prev'] = f_0
{torchzero-0.3.13 → torchzero-0.3.14}/torchzero/modules/restarts/restars.py
RENAMED

@@ -60,18 +60,18 @@ class RestartStrategyBase(Module, ABC):


 class RestartOnStuck(RestartStrategyBase):
-    """Resets the state when update (difference in parameters) is
+    """Resets the state when update (difference in parameters) is zero for multiple steps in a row.

     Args:
         modules (Chainable | None):
             modules to reset. If None, resets all modules.
         tol (float, optional):
-            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to
+            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest respresentable number)
         n_tol (int, optional):
-            number of failed consequtive steps required to trigger a reset. Defaults to
+            number of failed consequtive steps required to trigger a reset. Defaults to 10.

     """
-    def __init__(self, modules: Chainable | None, tol: float =
+    def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
         defaults = dict(tol=tol, n_tol=n_tol)
         super().__init__(defaults, modules)

@@ -82,6 +82,7 @@ class RestartOnStuck(RestartStrategyBase):

         params = TensorList(var.params)
         tol = self.defaults['tol']
+        if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']
         n_bad = self.global_state.get('n_bad', 0)
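`RestartOnStuck` now defaults `tol` to `None` (twice the smallest representable value for the parameters' dtype) and `n_tol` to 10. A usage sketch wrapping an L-BFGS stack (module names taken from examples elsewhere in this diff; `model` is assumed to exist):

```py
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    # reset the wrapped modules when parameters stop changing for n_tol consecutive steps
    tz.m.RestartOnStuck([tz.m.LBFGS(), tz.m.Backtracking()], n_tol=10),
)
```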