torchzero 0.4.1__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.4.1 → torchzero-0.4.2}/PKG-INFO +1 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/pyproject.toml +1 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/__init__.py +3 -1
- {torchzero-0.4.1/torchzero/optim/wrappers → torchzero-0.4.2/torchzero/_minimize}/__init__.py +0 -0
- torchzero-0.4.2/torchzero/_minimize/methods.py +95 -0
- torchzero-0.4.2/torchzero/_minimize/minimize.py +518 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/__init__.py +5 -5
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/chain.py +2 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/functional.py +2 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/module.py +75 -4
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/transform.py +6 -5
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/eigh.py +116 -68
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/linear_operator.py +1 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/orthogonalize.py +60 -5
- torchzero-0.4.2/torchzero/linalg/sketch.py +39 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/__init__.py +1 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/adagrad.py +2 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/adam.py +5 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/adan.py +3 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/ggt.py +20 -18
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/lion.py +3 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/mars.py +6 -5
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/msam.py +3 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/rmsprop.py +2 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/rprop.py +9 -7
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/shampoo.py +9 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/soap.py +32 -29
- torchzero-0.4.2/torchzero/modules/basis/__init__.py +2 -0
- torchzero-0.4.2/torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero-0.4.2/torchzero/modules/basis/soap_basis.py +254 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/clipping/ema_clipping.py +32 -27
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/clipping/growth_clipping.py +1 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/__init__.py +1 -6
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/coordinate_momentum.py +2 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/cubic_adam.py +4 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/grad_approximation/__init__.py +3 -2
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/least_squares/gn.py +6 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/gradient_accumulation.py +1 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/misc.py +6 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/momentum/averaging.py +6 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/momentum/momentum.py +4 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/__init__.py +0 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/accumulate.py +4 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/higher_level.py +6 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/inm.py +4 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/newton.py +11 -3
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/newton_cg.py +7 -3
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/nystrom.py +14 -19
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/rsn.py +37 -6
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/trust_region.py +2 -1
- {torchzero-0.4.1/torchzero/utils/benchmarks → torchzero-0.4.2/torchzero/optim/wrappers}/__init__.py +0 -0
- torchzero-0.4.2/torchzero/utils/benchmarks/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/benchmarks/logistic.py +33 -18
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/params.py +13 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero.egg-info/PKG-INFO +1 -1
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero.egg-info/SOURCES.txt +8 -5
- torchzero-0.4.1/torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero-0.4.1/torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero-0.4.1/torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero-0.4.1/torchzero/modules/experimental/eigengrad.py +0 -207
- {torchzero-0.4.1 → torchzero-0.4.2}/setup.cfg +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_identical.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_module.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_module_autograd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_objective.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_opts.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_tensorlist.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/modular.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/objective.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/core/reformulation.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/benchmark.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/linalg_utils.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/matrix_power.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/qr.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/solve.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/svd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/linalg/torch_linalg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/adahessian.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/aegd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/esgd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/lre_optimizers.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/muon.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/natural_gradient.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/orthograd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/_psgd_utils.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/sam.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/adaptive/sophia_h.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/conjugate_gradient/cg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/dct.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/fft.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/higher_order_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/l_infinity.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/matrix_nag.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/newtonnewton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/experimental/structural_projections.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.4.1/torchzero/modules/experimental → torchzero-0.4.2/torchzero/modules/grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/least_squares/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/_polyinterp.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/adaptive.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/interpolation.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/debug.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/escape.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/homotopy.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/multistep.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/regularization.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/split.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/misc/switch.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/opt_utils.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/projections/cast.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/damping.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/lsr1.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/quasi_newton/sg2.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/restarts/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/restarts/restars.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/ifn.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/second_order/multipoint.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/smoothing/sampling.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/step_size/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/step_size/adaptive.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/step_size/lr.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/termination/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/termination/termination.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/dogleg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/levenberg_marquardt.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/trust_region/trust_cg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/variance_reduction/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/variance_reduction/svrg.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/weight_decay/reinit.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/zeroth_order/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/modules/zeroth_order/cd.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/mbs.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/root.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/directsearch.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/fcmaes.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/mads.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/moors.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/optuna.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/pybobyqa.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/brute.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/direct.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/optim/wrappers/wrapper.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/compile.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/metrics.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/thoad_tools.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.4.1 → torchzero-0.4.2}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.4.1/torchzero/optim/wrappers → torchzero-0.4.2/torchzero/_minimize}/__init__.py
RENAMED
|
File without changes
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""WIP API"""
|
|
2
|
+
import itertools
|
|
3
|
+
import time
|
|
4
|
+
from collections import deque
|
|
5
|
+
from collections.abc import Callable, Sequence, Mapping, Iterable
|
|
6
|
+
from typing import Any, NamedTuple, cast, overload
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from .. import m
|
|
12
|
+
from ..core import Module, Optimizer
|
|
13
|
+
from ..utils import tofloat
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_method_from_str(method: str) -> list[Module]:
    """Build the list of modules implementing a named solver.

    The name is normalised before lookup: it is lower-cased, stripped, and every
    non-alphanumeric character is dropped. So ``"L-BFGS"``, ``"l_bfgs"`` and
    ``"lbfgs"`` all select the same solver.

    Args:
        method: solver name, e.g. ``"bfgs"``, ``"newton"``, ``"lm"``, ``"cg"``.

    Returns:
        A list of freshly constructed ``Module`` objects (new instances on every
        call), in the order they should be chained.

    Raises:
        NotImplementedError: if the normalised name is unknown. The message is the
            normalised name.
    """
    key = "".join(filter(str.isalnum, method.lower().strip()))

    # each entry is a zero-argument factory so modules are only built for the chosen name
    factories = {
        "bfgs": lambda: [m.RestartOnStuck(m.BFGS()), m.Backtracking()],
        "lbfgs": lambda: [m.LBFGS(100), m.Backtracking()],
        "newton": lambda: [m.Newton(), m.Backtracking()],
        "sfn": lambda: [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()],
        "inm": lambda: [m.ImprovedNewton(), m.Backtracking()],
        "crn": lambda: [m.CubicRegularization(m.Newton())],
        "commondirections": lambda: [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()],
        "trust": lambda: [m.LevenbergMarquardt(m.Newton())],
        "trustexact": lambda: [m.TrustCG(m.Newton())],
        "dogleg": lambda: [m.Dogleg(m.Newton())],
        "trustbfgs": lambda: [m.LevenbergMarquardt(m.BFGS())],
        "trustsr1": lambda: [m.LevenbergMarquardt(m.SR1())],
        "newtoncg": lambda: [m.NewtonCG(), m.Backtracking()],
        "tn": lambda: [m.NewtonCG(maxiter=10), m.Backtracking()],
        "trustncg": lambda: [m.NewtonCGSteihaug()],
        "gd": lambda: [m.Backtracking()],
        "cg": lambda: [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)],
        "bb": lambda: [m.RestartOnStuck(m.BarzilaiBorwein())],
        "bbstab": lambda: [m.BBStab()],
        "adgd": lambda: [m.AdGD()],
        "gn": lambda: [m.GaussNewton(), m.Backtracking()],
        "gaussnewton": lambda: [m.GaussNewton(), m.Backtracking()],
        "rprop": lambda: [m.Rprop(alpha=1e-3)],
        "lm": lambda: [m.LevenbergMarquardt(m.GaussNewton())],
        "mlm": lambda: [m.LevenbergMarquardt(m.GaussNewton(), y=1)],
        "cd": lambda: [m.CD(), m.ScipyMinimizeScalar(maxiter=8)],
    }

    factory = factories.get(key)
    if factory is None:
        raise NotImplementedError(key)
    return factory()
|
|
@@ -0,0 +1,518 @@
|
|
|
1
|
+
"""WIP API"""
|
|
2
|
+
import itertools
|
|
3
|
+
import time
|
|
4
|
+
from collections import deque
|
|
5
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
6
|
+
from typing import Any, NamedTuple, cast, overload
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ..core import Module, Optimizer
|
|
12
|
+
from ..utils import tofloat
|
|
13
|
+
from .methods import _get_method_from_str
|
|
14
|
+
|
|
15
|
+
# Objective evaluated through autograd: called with the parameter tensor(s) and
# may return a scalar or a multi-element tensor (``_WrappedFunc`` reduces the
# latter with ``reduce_fn`` for bookkeeping).
_fn_autograd = Callable[[torch.Tensor], torch.Tensor | Any]

# Objective with a user-supplied gradient, called as ``f(x, grad)``. Per the
# ``minimize`` docstring it returns the value and, when ``grad.numel() > 0``,
# writes the gradient into ``grad`` in place.
_fn_custom_grad = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

# Type of the tolerance arguments of ``minimize`` (ftol, gtol, xtol).
_scalar = float | np.ndarray | torch.Tensor

# Accepted ``method`` argument: a solver name, a single Module, a sequence of
# Modules, or a callable that builds an optimizer from a list of parameters.
_method = (
    str
    | Module
    | Sequence[Module]
    | Callable[..., torch.optim.Optimizer]
)
|
|
19
|
+
|
|
20
|
+
def _tensorlist_norm(tensors: Iterable[torch.Tensor], ord) -> torch.Tensor:
|
|
21
|
+
"""returns a scalar - global norm of tensors"""
|
|
22
|
+
if ord == torch.inf:
|
|
23
|
+
return max(torch._foreach_max(torch._foreach_abs(tuple(tensors))))
|
|
24
|
+
|
|
25
|
+
if ord == 1:
|
|
26
|
+
return cast(torch.Tensor, sum(t.abs().sum() for t in tensors))
|
|
27
|
+
|
|
28
|
+
if ord % 2 != 0:
|
|
29
|
+
tensors = torch._foreach_abs(tuple(tensors))
|
|
30
|
+
|
|
31
|
+
tensors = torch._foreach_pow(tuple(tensors), ord)
|
|
32
|
+
return sum(t.sum() for t in tensors) ** (1 / ord)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Params:
    """Container for the optimization variables: positional ``args`` and keyword ``kwargs``.

    The objective is invoked as ``f(*args, **kwargs)``. On construction, ``args``
    is copied into a tuple and ``kwargs`` into a dict. The tensors themselves are
    not copied.
    """

    __slots__ = ("args", "kwargs")

    def __init__(self, args: Sequence[torch.Tensor], kwargs: Mapping[str, torch.Tensor]):
        self.args = tuple(args)
        self.kwargs = dict(kwargs)

    @property
    def x(self):
        """The single tensor held, valid only with one positional argument and no kwargs.

        Otherwise an ``AssertionError`` is raised.
        """
        assert len(self.args) == 1
        assert len(self.kwargs) == 0
        return self.args[0]

    def parameters(self):
        """Yield every tensor: positional ones first, then keyword values in insertion order."""
        for tensor in self.args:
            yield tensor
        for tensor in self.kwargs.values():
            yield tensor

    def _map(self, fn):
        """Return a new ``Params`` with ``fn`` applied to every tensor (args first, then kwargs)."""
        return Params(
            args=[fn(tensor) for tensor in self.args],
            kwargs={name: fn(tensor) for name, tensor in self.kwargs.items()},
        )

    def clone(self):
        return self._map(lambda t: t.clone())

    def __repr__(self):
        if len(self.args) == 1 and not self.kwargs:
            return f"Params({self.x!r})"

        pieces = ["Params("]
        if self.args:
            joined_args = ",\n\t\t".join(str(a) for a in self.args)
            pieces.append(f"\n\targs = (\n\t\t{joined_args}\n\t)")
        if self.kwargs:
            # f-string formatting (not str()) is used for values, as in the original layout
            joined_kwargs = ",\n\t\t".join(f"{name}={value}" for name, value in self.kwargs.items())
            pieces.append(f"\n\tkwargs = (\n\t\t{joined_kwargs},\n\t)")
        pieces.append("\n)")
        return "".join(pieces)

    def _call(self, f):
        """Call ``f`` with the stored positional and keyword tensors."""
        return f(*self.args, **self.kwargs)

    def _detach_clone(self):
        return self._map(lambda t: t.detach().clone())

    def _detach_cpu_clone(self):
        return self._map(lambda t: t.detach().cpu().clone())

    def _requires_grad_(self, mode=True):
        """Set ``requires_grad`` in place on every tensor; returns a Params of the same tensors."""
        return self._map(lambda t: t.requires_grad_(mode))

    def _grads(self):
        """Return gradients of all tensors (zeros where missing), or None if no tensor has one."""
        tensors = tuple(self.parameters())
        grads = [t.grad for t in tensors]
        if all(g is None for g in grads):
            return None
        return [torch.zeros_like(t) if g is None else g for t, g in zip(tensors, grads)]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Every form accepted as the initial point of ``minimize``; ``_make_params``
# turns it into a ``Params`` object.
_x0 = (
    torch.Tensor                                                          # one tensor -> single positional argument
    | Sequence[torch.Tensor]                                              # positional arguments
    | Mapping[str, torch.Tensor]                                          # keyword arguments
    | Mapping[str, Sequence[torch.Tensor] | Mapping[str, torch.Tensor]]  # {"args": ..., "kwargs": ...}
    | tuple[Sequence[torch.Tensor], Mapping[str, torch.Tensor]]          # (args, kwargs)
    | Sequence[Sequence[torch.Tensor] | Mapping[str, torch.Tensor]]      # (args, kwargs) as a generic sequence
    | Params                                                              # already normalized
)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _get_opt_fn(method: _method):
    """Return a factory that builds an optimizer from a list of parameters.

    Args:
        method: a solver name (resolved via ``_get_method_from_str`` when the
            factory is called), a single ``Module``, a sequence of ``Module``s,
            or a callable that is already such a factory.

    Raises:
        ValueError: if ``method`` is none of the above.
    """
    if isinstance(method, str):
        def from_name(params):
            return Optimizer(params, *_get_method_from_str(method))
        return from_name

    if isinstance(method, Module):
        def from_module(params):
            return Optimizer(params, method)
        return from_module

    if isinstance(method, Sequence):
        def from_modules(params):
            return Optimizer(params, *method)
        return from_modules

    if not callable(method):
        raise ValueError(method)
    return method
|
|
130
|
+
|
|
131
|
+
def _is_scalar(x):
|
|
132
|
+
if isinstance(x, torch.Tensor): return x.numel() == 1
|
|
133
|
+
if isinstance(x, np.ndarray): return x.size == 1
|
|
134
|
+
return True
|
|
135
|
+
|
|
136
|
+
def _maybe_detach_cpu(x):
|
|
137
|
+
if isinstance(x, torch.Tensor): return x.detach().cpu()
|
|
138
|
+
return x
|
|
139
|
+
|
|
140
|
+
class _MaxEvaluationsReached(Exception): pass
|
|
141
|
+
class _MaxSecondsReached(Exception): pass
|
|
142
|
+
class Terminate(Exception): pass
|
|
143
|
+
|
|
144
|
+
class _WrappedFunc:
    """Callable wrapper around the user objective used by ``minimize``.

    On every call it enforces the ``maxeval`` and ``maxsec`` budgets by raising
    ``_MaxEvaluationsReached`` or ``_MaxSecondsReached`` *before* evaluating. It
    also counts evaluations, tracks the best value and parameters seen so far,
    and optionally records a history of evaluated points.

    Args:
        f: objective. Called as ``f(*x.args, **x.kwargs)``, or as ``f(x, g)`` when
            ``custom_grad`` is True.
        x0: initial parameters. A detached copy is the initial ``x_best``.
        reduce_fn: reduces a multi-element objective value to a scalar for
            best-value tracking.
        max_history: number of most recent evaluations to keep. -1 keeps all;
            0 keeps none, and ``history`` is then None.
        maxeval: maximum number of evaluations, or None for no limit.
        maxsec: wall-clock budget in seconds counted from construction, or None.
        custom_grad: whether ``f`` has the ``f(x, grad)`` signature.
    """

    def __init__(self, f: _fn_autograd | _fn_custom_grad, x0: Params, reduce_fn: Callable, max_history,
                 maxeval: int | None, maxsec: float | None, custom_grad: bool):
        self.f = f
        self.maxeval = maxeval
        self.reduce_fn = reduce_fn
        self.custom_grad = custom_grad
        self.maxsec = maxsec

        # Detached copy. ``minimize`` enables requires_grad on x0's tensors before
        # building this wrapper, so a plain ``clone()`` would produce non-leaf
        # tensors attached to the autograd graph. If no evaluation ever improves on
        # ``fmin`` (e.g. maxeval=0, or every value is NaN), this copy stays as x_best.
        self.x_best = x0._detach_clone()
        self.fmin = float("inf")
        self.evals = 0
        self.start = time.time()

        if max_history == -1:
            max_history = None  # unlimited history
        if max_history == 0:
            self.history = None
        else:
            self.history = deque(maxlen=max_history)

    def __call__(self, x: Params, g: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Evaluate the objective at ``x``.

        Returns:
            ``(v, g)``: the raw objective value and ``g``. With ``custom_grad``,
            ``f`` is expected to fill ``g``; otherwise ``g`` is returned as passed.
        """
        if self.maxeval is not None and self.evals >= self.maxeval:
            raise _MaxEvaluationsReached

        if self.maxsec is not None and time.time() - self.start >= self.maxsec:
            raise _MaxSecondsReached

        self.evals += 1

        if self.custom_grad:
            assert g is not None
            assert len(x.args) == 1 and len(x.kwargs) == 0
            v = v_scalar = cast(_fn_custom_grad, self.f)(x.x, g)
        else:
            v = v_scalar = x._call(self.f)

        with torch.no_grad():

            # multi-value v, reduce using reduce func
            if isinstance(v, torch.Tensor) and v.numel() > 1:
                v_scalar = self.reduce_fn(v)

            # NaN never compares less, so a NaN value never replaces the best point
            if v_scalar < self.fmin:
                self.fmin = tofloat(v_scalar)
                self.x_best = x._detach_clone()

            if self.history is not None:
                self.history.append((x._detach_cpu_clone(), _maybe_detach_cpu(v)))

        return v, g
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class MinimizeResult(NamedTuple):
    """Outcome of a ``minimize`` run.

    ``repr()`` prints one ``name: value`` line per summary field. ``losses`` and
    ``history`` are not included in it.
    """
    params: Params
    x: torch.Tensor | None
    success: bool
    message: str
    fun: float
    n_iters: int
    n_evals: int
    g_norm: torch.Tensor | None
    dir_norm: torch.Tensor | None
    losses: list[float]
    # NOTE(review): ``_WrappedFunc.history`` stores ``(Params, value)`` pairs and is
    # None when ``max_history == 0``. If that is what fills this field, the
    # annotation below is inaccurate — confirm.
    history: deque[tuple[torch.Tensor, torch.Tensor]]

    def __repr__(self):
        # continuation lines of a multi-line params repr are indented by 10 spaces
        params_repr = repr(self.params).replace("\n", "\n" + " " * 10)
        rows = (
            ("message", self.message),
            ("success", self.success),
            ("fun", self.fun),
            ("params", params_repr),
            ("x", self.x),
            ("n_iters", self.n_iters),
            ("n_evals", self.n_evals),
            ("g_norm", self.g_norm),
            ("dir_norm", self.dir_norm),
        )
        return "".join(f"{label}: {value}\n" for label, value in rows)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _make_params(x0: _x0):
    """Convert any accepted form of ``x0`` into a ``Params`` object.

    Accepted forms:

    * a ``Params`` object -- returned unchanged;
    * a single ``torch.Tensor`` -- becomes the only positional argument;
    * a sequence whose first element is a tensor -- positional arguments;
    * a two-element sequence ``(args, kwargs)`` of a sequence and a mapping;
    * a mapping with an ``"args"`` and/or ``"kwargs"`` key;
    * any other mapping -- keyword arguments.

    Raises:
        ValueError: if ``x0`` is an empty sequence.
        TypeError: if ``x0`` is not one of the forms above. Explicit exceptions
            are used instead of ``assert`` so validation still runs under ``python -O``.
    """
    x = cast(Any, x0)

    # already normalized
    if isinstance(x, Params):
        return x

    # single tensor -> one positional argument
    if isinstance(x, torch.Tensor):
        return Params(args=(x,), kwargs={})

    # strings are Sequences too, but never a valid x0
    if isinstance(x, Sequence) and not isinstance(x, str):
        if len(x) == 0:
            raise ValueError("x0 must not be an empty sequence")

        # sequence of tensors -> positional arguments
        if isinstance(x[0], torch.Tensor):
            return Params(args=x, kwargs={})

        # (args, kwargs) pair
        if len(x) == 2 and isinstance(x[0], Sequence) and isinstance(x[1], Mapping):
            return Params(args=x[0], kwargs=x[1])

        raise TypeError(
            "a non-tensor sequence x0 must be an (args, kwargs) pair of a sequence and a mapping"
        )

    if isinstance(x, Mapping):
        # dict with args and/or kwargs
        if "args" in x or "kwargs" in x:
            return Params(args=x.get("args", ()), kwargs=x.get("kwargs", {}))

        # plain mapping -> keyword arguments
        return Params(args=(), kwargs=x)

    raise TypeError(type(x))
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def minimize(
|
|
253
|
+
f: _fn_autograd | _fn_custom_grad,
|
|
254
|
+
x0: _x0,
|
|
255
|
+
|
|
256
|
+
method: _method | None = None,
|
|
257
|
+
|
|
258
|
+
maxeval: int | None = None,
|
|
259
|
+
maxiter: int | None = None,
|
|
260
|
+
maxsec: float | None = None,
|
|
261
|
+
ftol: _scalar | None = None,
|
|
262
|
+
gtol: _scalar | None = 1e-5,
|
|
263
|
+
xtol: _scalar | None = None,
|
|
264
|
+
max_no_improvement_iters: int | None = 100,
|
|
265
|
+
|
|
266
|
+
reduce_fn: Callable[[torch.Tensor], torch.Tensor] = torch.sum,
|
|
267
|
+
max_history: int = 0,
|
|
268
|
+
|
|
269
|
+
custom_grad: bool = False,
|
|
270
|
+
use_termination_exceptions: bool = True,
|
|
271
|
+
norm = torch.inf,
|
|
272
|
+
|
|
273
|
+
) -> MinimizeResult:
|
|
274
|
+
"""Minimize a scalar or multiobjective function of one or more variables.
|
|
275
|
+
|
|
276
|
+
Args:
|
|
277
|
+
f (_fn_autograd | _fn_custom_grad):
|
|
278
|
+
The objective function to be minimized.
|
|
279
|
+
x0 (_x0):
|
|
280
|
+
Initial guess. Can be torch.Tensor, tuple of torch.Tensors to pass as args,
|
|
281
|
+
or dictionary of torch.Tensors to pass as kwargs.
|
|
282
|
+
method (_method | None, optional):
|
|
283
|
+
Type of solver. Can be a string, a ``Module`` (like ``tz.m.BFGS()``), or a list of ``Module``.
|
|
284
|
+
By default chooses BFGS or L-BFGS depending on number of variables. Defaults to None.
|
|
285
|
+
maxeval (int | None, optional):
|
|
286
|
+
terminate when exceeded this number of function evaluations. Defaults to None.
|
|
287
|
+
maxiter (int | None, optional):
|
|
288
|
+
terminate when exceeded this number of solver iterations,
|
|
289
|
+
each iteration may perform multiple function evaluations. Defaults to None.
|
|
290
|
+
maxsec (float | None, optional):
|
|
291
|
+
terminate after optimizing for this many seconds. Defaults to None.
|
|
292
|
+
ftol (_scalar | None, optional):
|
|
293
|
+
terminate when reached a solution with objective value less or equal to this value. Defaults to None.
|
|
294
|
+
gtol (_scalar | None, optional):
|
|
295
|
+
terminate when gradient norm is less or equal to this value.
|
|
296
|
+
The type of norm is controlled by ``norm`` argument and is infinity norm by default. Defaults to 1e-5.
|
|
297
|
+
xtol (_scalar | None, optional):
|
|
298
|
+
terminate when norm of difference between successive parameters is less or equal to this value. Defaults to None.
|
|
299
|
+
max_no_improvement_iters (int | None, optional):
|
|
300
|
+
terminate when objective value hasn't improved once for this many consecutive iterations. Defaults to 100.
|
|
301
|
+
reduce_fn (Callable[[torch.Tensor], torch.Tensor], optional):
|
|
302
|
+
only has effect when ``f`` is multi-objective / least-squares. Determines how to convert
|
|
303
|
+
vector returned by ``f`` to a single scalar value for ``ftol`` and ``max_no_improvement_iters``.
|
|
304
|
+
Defaults to torch.sum.
|
|
305
|
+
max_history (int, optional):
|
|
306
|
+
stores this many last evaluated parameters and their values.
|
|
307
|
+
Set to -1 to store all parameters. Set to 0 to store nothing (default).
|
|
308
|
+
custom_grad (bool, optional):
|
|
309
|
+
Allows specifying a custom gradient function instead of using autograd.
|
|
310
|
+
if True, objective function ``f`` must of the following form:
|
|
311
|
+
```python
|
|
312
|
+
def f(x, grad):
|
|
313
|
+
value = objective(x)
|
|
314
|
+
if grad.numel() > 0:
|
|
315
|
+
grad[:] = objective_gradient(x)
|
|
316
|
+
return value
|
|
317
|
+
```
|
|
318
|
+
|
|
319
|
+
Defaults to False.
|
|
320
|
+
use_termination_exceptions (bool, optional):
|
|
321
|
+
if True, ``maxeval`` and ``maxsec`` use exceptions to terminate, therefore they are able to trigger
|
|
322
|
+
mid-iteration. If False, they can only trigger after iteration, so it might perform slightly more
|
|
323
|
+
evals and for slightly more seconds than requested. Defaults to True.
|
|
324
|
+
norm (float, optional):
|
|
325
|
+
type of norm to use for gradient and update tolerances. Defaults to torch.inf.
|
|
326
|
+
|
|
327
|
+
Raises:
|
|
328
|
+
RuntimeError: if ``custom_grad`` is True and ``x0`` is not a single tensor.
|
|
329
|
+
|
|
330
|
+
Returns:
|
|
331
|
+
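MinimizeResult: the best parameters found (``params``, plus ``x`` when ``x0`` is a single tensor), the best objective value (``fun``), the success flag and termination message, iteration and evaluation counts, the last gradient and update norms, per-iteration losses, and the evaluation history.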
|
|
332
|
+
"""
|
|
333
|
+
|
|
334
|
+
x0 = _make_params(x0)
|
|
335
|
+
x = x0._requires_grad_(True)
|
|
336
|
+
|
|
337
|
+
# checks
|
|
338
|
+
if custom_grad:
|
|
339
|
+
if not (len(x.args) == 1 and len(x.kwargs) == 0):
|
|
340
|
+
raise RuntimeError("custom_grad only works when `x` is a single tensor.")
|
|
341
|
+
|
|
342
|
+
# determine method if None
|
|
343
|
+
if method is None:
|
|
344
|
+
max_dim = 5_000 if next(iter(x.parameters())).is_cuda else 1_000
|
|
345
|
+
if sum(p.numel() for p in x.parameters()) > max_dim: method = 'lbfgs'
|
|
346
|
+
else: method = 'bfgs'
|
|
347
|
+
|
|
348
|
+
opt_fn = _get_opt_fn(method)
|
|
349
|
+
optimizer = opt_fn(list(x.parameters()))
|
|
350
|
+
|
|
351
|
+
f_wrapped = _WrappedFunc(
|
|
352
|
+
f,
|
|
353
|
+
x0=x0,
|
|
354
|
+
reduce_fn=reduce_fn,
|
|
355
|
+
max_history=max_history,
|
|
356
|
+
maxeval=maxeval,
|
|
357
|
+
custom_grad=custom_grad,
|
|
358
|
+
maxsec=maxsec,
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
def closure(backward=True):
|
|
362
|
+
|
|
363
|
+
g = None
|
|
364
|
+
v = None
|
|
365
|
+
if custom_grad:
|
|
366
|
+
v = x.x
|
|
367
|
+
if backward: g = torch.empty_like(v)
|
|
368
|
+
else: g = torch.empty(0, device=v.device, dtype=v.dtype)
|
|
369
|
+
|
|
370
|
+
loss, g = f_wrapped(x, g=g)
|
|
371
|
+
|
|
372
|
+
if backward:
|
|
373
|
+
|
|
374
|
+
# custom gradients provided by user
|
|
375
|
+
if g is not None:
|
|
376
|
+
assert v is not None
|
|
377
|
+
v.grad = g
|
|
378
|
+
|
|
379
|
+
# autograd
|
|
380
|
+
else:
|
|
381
|
+
optimizer.zero_grad()
|
|
382
|
+
loss.backward()
|
|
383
|
+
|
|
384
|
+
return loss
|
|
385
|
+
|
|
386
|
+
losses = []
|
|
387
|
+
|
|
388
|
+
tiny = torch.finfo(list(x0.parameters())[0].dtype).tiny ** 2
|
|
389
|
+
if gtol == 0: gtol = tiny
|
|
390
|
+
if xtol == 0: xtol = tiny
|
|
391
|
+
|
|
392
|
+
p_prev = None if xtol is None else [p.detach().clone() for p in x.parameters()]
|
|
393
|
+
fmin = float("inf")
|
|
394
|
+
niter = 0
|
|
395
|
+
n_no_improvement = 0
|
|
396
|
+
g_norm = None
|
|
397
|
+
dir_norm = None
|
|
398
|
+
|
|
399
|
+
terminate_msg = "max iterations reached"
|
|
400
|
+
success = False
|
|
401
|
+
|
|
402
|
+
exceptions: list | tuple = [Terminate]
|
|
403
|
+
if use_termination_exceptions:
|
|
404
|
+
if maxeval is not None: exceptions.append(_MaxEvaluationsReached)
|
|
405
|
+
if maxsec is not None: exceptions.append(_MaxSecondsReached)
|
|
406
|
+
exceptions = tuple(exceptions)
|
|
407
|
+
|
|
408
|
+
for i in (range(maxiter) if maxiter is not None else itertools.count()):
|
|
409
|
+
niter += 1
|
|
410
|
+
|
|
411
|
+
# ----------------------------------- step ----------------------------------- #
|
|
412
|
+
try:
|
|
413
|
+
v = v_scalar = optimizer.step(closure) # pyright:ignore[reportCallIssue,reportArgumentType]
|
|
414
|
+
except exceptions:
|
|
415
|
+
break
|
|
416
|
+
|
|
417
|
+
with torch.no_grad():
|
|
418
|
+
assert v is not None and v_scalar is not None
|
|
419
|
+
|
|
420
|
+
if isinstance(v, torch.Tensor) and v.numel() > 1:
|
|
421
|
+
v_scalar = reduce_fn(v)
|
|
422
|
+
|
|
423
|
+
losses.append(tofloat(v_scalar))
|
|
424
|
+
|
|
425
|
+
# --------------------------- termination criteria --------------------------- #
|
|
426
|
+
|
|
427
|
+
# termination criteria on optimizer
|
|
428
|
+
if isinstance(optimizer, Optimizer) and optimizer.should_terminate:
|
|
429
|
+
terminate_msg = 'optimizer-specific termination criteria triggered'
|
|
430
|
+
success = True
|
|
431
|
+
break
|
|
432
|
+
|
|
433
|
+
# max seconds (when use_termination_exceptions=False)
|
|
434
|
+
if maxsec is not None and time.time() - f_wrapped.start >= maxsec:
|
|
435
|
+
terminate_msg = 'max seconds reached'
|
|
436
|
+
success = False
|
|
437
|
+
break
|
|
438
|
+
|
|
439
|
+
# max evals (when use_termination_exceptions=False)
|
|
440
|
+
if maxeval is not None and f_wrapped.evals >= maxeval:
|
|
441
|
+
terminate_msg = 'max evaluations reached'
|
|
442
|
+
success = False
|
|
443
|
+
break
|
|
444
|
+
|
|
445
|
+
# min function value
|
|
446
|
+
if ftol is not None and v_scalar <= ftol:
|
|
447
|
+
terminate_msg = 'target function value reached'
|
|
448
|
+
success = True
|
|
449
|
+
break
|
|
450
|
+
|
|
451
|
+
# gradient norm (norm type set by ``norm``, infinity norm by default)
|
|
452
|
+
if gtol is not None:
|
|
453
|
+
grads = x._grads()
|
|
454
|
+
if grads is not None:
|
|
455
|
+
g_norm = _tensorlist_norm(grads, norm)
|
|
456
|
+
if g_norm <= gtol:
|
|
457
|
+
terminate_msg = 'gradient norm is below tolerance'
|
|
458
|
+
success = True
|
|
459
|
+
break
|
|
460
|
+
|
|
461
|
+
# due to the way torchzero works we sometimes don't populate .grad,
|
|
462
|
+
# e.g. with Newton, therefore fallback on xtol
|
|
463
|
+
else:
|
|
464
|
+
if xtol is None: xtol = tiny
|
|
465
|
+
|
|
466
|
+
# difference in parameters
|
|
467
|
+
if xtol is not None:
|
|
468
|
+
p_new = [p.detach().clone() for p in x.parameters()]
|
|
469
|
+
|
|
470
|
+
if p_prev is None: # happens when xtol is set in gtol logic
|
|
471
|
+
p_prev = p_new
|
|
472
|
+
|
|
473
|
+
else:
|
|
474
|
+
dir_norm = _tensorlist_norm(torch._foreach_sub(p_new, p_prev), norm)
|
|
475
|
+
if dir_norm <= xtol:
|
|
476
|
+
terminate_msg = 'update norm is below tolerance'
|
|
477
|
+
success = True
|
|
478
|
+
break
|
|
479
|
+
|
|
480
|
+
p_prev = p_new
|
|
481
|
+
|
|
482
|
+
# no improvement steps
|
|
483
|
+
if max_no_improvement_iters is not None:
|
|
484
|
+
if f_wrapped.fmin >= fmin:
|
|
485
|
+
n_no_improvement += 1
|
|
486
|
+
else:
|
|
487
|
+
fmin = f_wrapped.fmin
|
|
488
|
+
n_no_improvement = 0
|
|
489
|
+
|
|
490
|
+
if n_no_improvement >= max_no_improvement_iters:
|
|
491
|
+
terminate_msg = 'reached maximum steps without improvement'
|
|
492
|
+
success = False
|
|
493
|
+
break
|
|
494
|
+
|
|
495
|
+
history=f_wrapped.history
|
|
496
|
+
if history is None: history = deque()
|
|
497
|
+
|
|
498
|
+
x_vec = None
|
|
499
|
+
if len(x0.args) == 1 and len(x0.kwargs) == 0:
|
|
500
|
+
x_vec = f_wrapped.x_best.x
|
|
501
|
+
|
|
502
|
+
result = MinimizeResult(
|
|
503
|
+
params = f_wrapped.x_best,
|
|
504
|
+
x = x_vec,
|
|
505
|
+
success = success,
|
|
506
|
+
message = terminate_msg,
|
|
507
|
+
fun = f_wrapped.fmin,
|
|
508
|
+
n_iters = niter,
|
|
509
|
+
n_evals = f_wrapped.evals,
|
|
510
|
+
g_norm = g_norm,
|
|
511
|
+
dir_norm = dir_norm,
|
|
512
|
+
losses = losses,
|
|
513
|
+
history = history,
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
return result
|
|
517
|
+
|
|
518
|
+
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from .
|
|
2
|
-
from .
|
|
3
|
-
from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
|
|
1
|
+
from .chain import Chain, maybe_chain
|
|
2
|
+
from .functional import apply, step, step_tensors, update
|
|
4
3
|
|
|
5
4
|
# order is important to avoid circular imports
|
|
6
5
|
from .modular import Optimizer
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
6
|
+
from .module import Module, Chainable, ProjectedBuffer
|
|
7
|
+
from .objective import Objective, DerivativesMethod, HessianMethod, HVPMethod
|
|
8
|
+
from .transform import TensorTransform, Transform
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
|
|
3
3
|
from ..utils.python_tools import flatten
|
|
4
|
-
from .module import Module, Chainable
|
|
5
4
|
from .functional import _chain_step
|
|
5
|
+
from .module import Chainable, Module
|
|
6
|
+
|
|
6
7
|
|
|
7
8
|
class Chain(Module):
|
|
8
9
|
"""Chain modules, mostly used internally"""
|
|
@@ -83,6 +83,7 @@ def step_tensors(
|
|
|
83
83
|
modules = (modules, )
|
|
84
84
|
|
|
85
85
|
# make fake params if they are only used for shapes
|
|
86
|
+
# note that if modules use states, tensors must always be the same python object
|
|
86
87
|
if params is None:
|
|
87
88
|
params = [t.view_as(t).requires_grad_() for t in tensors]
|
|
88
89
|
|
|
@@ -96,7 +97,7 @@ def step_tensors(
|
|
|
96
97
|
objective.updates = list(tensors)
|
|
97
98
|
|
|
98
99
|
# step with modules
|
|
99
|
-
# this won't update parameters in-place because objective.Optimizer is None
|
|
100
|
+
# this won't update parameters in-place (on modules with fused update) because objective.Optimizer is None
|
|
100
101
|
objective = _chain_step(objective, modules)
|
|
101
102
|
|
|
102
103
|
# return updates
|