torchzero 0.4.3__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.4.3 → torchzero-0.4.4}/PKG-INFO +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/pyproject.toml +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_opts.py +2 -2
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/methods.py +37 -32
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/module.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/benchmark.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/eigh.py +2 -2
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/linear_operator.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/matrix_power.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/orthogonalize.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/qr.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/solve.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/svd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adagrad.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adahessian.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adam.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adan.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/aegd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/esgd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/ggt.py +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/lion.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/mars.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/msam.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/muon.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/natural_gradient.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/orthograd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/rmsprop.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/rprop.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py +6 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py +15 -2
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/sophia_h.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/ggt_basis.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/soap_basis.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/cg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/coordinate_momentum.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/cubic_adam.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/dct.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/fft.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/higher_order_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/l_infinity.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/newtonnewton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/structural_projections.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/spsa1.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/least_squares/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/least_squares/gn.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/_polyinterp.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/adaptive.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/debug.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/escape.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/gradient_accumulation.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/misc.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py +2 -3
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/regularization.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/split.py +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/switch.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/higher_level.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/opt_utils.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/cast.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lsr1.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/inm.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/multipoint.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/newton.py +8 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/newton_cg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/nystrom.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/sampling.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/__init__.py +1 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/adaptive.py +42 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/lr.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/termination/termination.py +2 -1
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/dogleg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/levenberg_marquardt.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_cg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_region.py +0 -0
- torchzero-0.4.4/torchzero/modules/weight_decay/__init__.py +8 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/weight_decay/weight_decay.py +84 -7
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/root.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/directsearch.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/fcmaes.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/mads.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/optuna.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/compile.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/params.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/PKG-INFO +1 -1
- torchzero-0.4.3/torchzero/modules/weight_decay/__init__.py +0 -2
- {torchzero-0.4.3 → torchzero-0.4.4}/setup.cfg +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_identical.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_module.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_module_autograd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_objective.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_tensorlist.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/minimize.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/chain.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/functional.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/modular.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/objective.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/reformulation.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/transform.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/linalg_utils.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/sketch.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/lre_optimizers.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/_psgd_utils.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/matrix_nag.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/interpolation.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/homotopy.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/damping.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/sg2.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/restarts/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/restarts/restars.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/ifn.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/rsn.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/termination/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/variance_reduction/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/variance_reduction/svrg.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/weight_decay/reinit.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/zeroth_order/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/zeroth_order/cd.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/mbs.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/moors.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/pybobyqa.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/brute.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/direct.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/wrapper.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/benchmarks/__init__.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/benchmarks/logistic.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/metrics.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/thoad_tools.py +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/SOURCES.txt +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.4.3 → torchzero-0.4.4}/tests/test_opts.py
@@ -727,8 +727,8 @@ Adam = Run(
 )
 # ------------------------------ optimizers/soap ----------------------------- #
 SOAP = Run(
-    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
-    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
+    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(merge_small=True), tz.m.LR(0.4)),
+    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1, merge_small=True), tz.m.LR(1)),
     needs_closure=False,
     # merge and unmerge lrs are very different so need to test convergence separately somewhere
     func='rosen', steps=50, loss=4, merge_invariant=False,
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/methods.py
@@ -14,82 +14,87 @@ from ..utils import tofloat


 def _get_method_from_str(method: str) -> list[Module]:
-
+    stripped = ''.join(c for c in method.lower().strip() if c.isalnum())

-    if
+    if stripped == "bfgs":
         return [m.RestartOnStuck(m.BFGS()), m.Backtracking()]

-    if
+    if stripped == "lbfgs":
         return [m.LBFGS(100), m.Backtracking()]

-    if
+    if stripped == "newton":
         return [m.Newton(), m.Backtracking()]

-    if
+    if stripped == "sfn":
         return [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()]

-    if
+    if stripped == "inm":
         return [m.ImprovedNewton(), m.Backtracking()]

-    if
+    if stripped == 'crn':
         return [m.CubicRegularization(m.Newton())]

-    if
+    if stripped == "commondirections":
         return [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()]

-    if
+    if stripped == "trust":
         return [m.LevenbergMarquardt(m.Newton())]

-    if
-        return [m.TrustCG(m.Newton())]
-
-    if method == "dogleg":
+    if stripped == "dogleg":
         return [m.Dogleg(m.Newton())]

-    if
-        return [m.LevenbergMarquardt(m.BFGS())]
+    if stripped == "trustbfgs":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.BFGS()))]

-    if
-        return [m.LevenbergMarquardt(m.SR1())]
+    if stripped == "trustsr1":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.SR1()))]

-    if
+    if stripped == "newtoncg":
         return [m.NewtonCG(), m.Backtracking()]

-    if
+    if stripped == "tn":
         return [m.NewtonCG(maxiter=10), m.Backtracking()]

-    if
+    if stripped == "trustncg":
         return [m.NewtonCGSteihaug()]

-    if
+    if stripped == "gd":
         return [m.Backtracking()]

-    if
+    if stripped == "cg":
         return [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)]

-    if
+    if stripped in ("shor", "shorr"):
+        return [m.ShorR(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "pgm":
+        return [m.ProjectedGradientMethod(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "bb":
         return [m.RestartOnStuck(m.BarzilaiBorwein())]

-    if
+    if stripped == "bbstab":
         return [m.BBStab()]

-    if
+    if stripped == "adgd":
         return [m.AdGD()]

-    if
+    if stripped in ("bd", "bolddriver"):
+        return [m.BoldDriver()]
+
+    if stripped in ("gn", "gaussnewton"):
         return [m.GaussNewton(), m.Backtracking()]

-    if
+    if stripped == "rprop":
         return [m.Rprop(alpha=1e-3)]

-    if
+    if stripped == "lm":
         return [m.LevenbergMarquardt(m.GaussNewton())]

-    if
+    if stripped == "mlm":
         return [m.LevenbergMarquardt(m.GaussNewton(), y=1)]

-    if
+    if stripped == "cd":
         return [m.CD(), m.ScipyMinimizeScalar(maxiter=8)]

-
-    raise NotImplementedError(method)
+    raise NotImplementedError(stripped)
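The new stripped normalization makes the string lookup case-, whitespace-, and punctuation-insensitive, so spellings like "Newton-CG" or " L-BFGS " resolve to the same method. A quick standalone check of the normalization expression (plain Python, no torchzero imports needed):

    def normalize(method: str) -> str:
        # same expression as in _get_method_from_str above
        return ''.join(c for c in method.lower().strip() if c.isalnum())

    assert normalize("Newton-CG") == "newtoncg"
    assert normalize(" L-BFGS ") == "lbfgs"
    assert normalize("bold driver") == "bolddriver"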
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/eigh.py
@@ -285,8 +285,8 @@ def rank1_eigh(v: torch.Tensor):
     vv = v.dot(v)
     norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)

-    L = vv.unsqueeze(0) # (
-    Q = v.unsqueeze(-1) / norm # (m,
+    L = vv.unsqueeze(0) # (1, )
+    Q = v.unsqueeze(-1) / norm # (m, 1)

     return L, Q

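rank1_eigh returns the single nonzero eigenpair of the rank-1 matrix v v^T: the eigenvalue is v.v and the unit eigenvector is v/||v||, which is what the corrected shape comments now say. A minimal numeric check of that identity (standalone sketch, not torchzero code):

    import torch

    v = torch.randn(5, dtype=torch.float64)
    L = v.dot(v).unsqueeze(0)          # (1,)  eigenvalue v.v
    Q = (v / v.norm()).unsqueeze(-1)   # (5, 1) unit eigenvector
    # Q diag(L) Q^T reconstructs the rank-1 matrix v v^T
    assert torch.allclose(Q @ torch.diag(L) @ Q.T, torch.outer(v, v))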
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py
@@ -46,7 +46,7 @@ def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Ten
     try:
         return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable

-    except torch.linalg.LinAlgError as e:
+    except (torch.linalg.LinAlgError, RuntimeError) as e:
         if not retry_float64: raise e
         dtype = A.dtype
         if dtype == torch.float64: raise e
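Widening the except clause matters mainly on CUDA: torch.linalg.LinAlgError is, to my understanding, a RuntimeError subclass, while backend failures (e.g. cuSOLVER errors) surface as plain RuntimeError, which the old clause did not catch, so the float64 retry never ran for them. A minimal sketch of the retry pattern (illustrative, not the module's exact code):

    import torch

    def eigh_retry_f64(A: torch.Tensor):
        try:
            return torch.linalg.eigh(A)
        except RuntimeError:  # also covers torch.linalg.LinAlgError
            L, Q = torch.linalg.eigh(A.to(torch.float64))
            return L.to(A.dtype), Q.to(A.dtype)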
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py
@@ -1,7 +1,10 @@
+from collections.abc import Mapping, Sequence
 from contextlib import nullcontext
+from typing import Any
 import torch
+
 from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
-from ...core import Transform
+from ...core import Transform, Objective


 class SAM(Transform):
@@ -126,6 +129,8 @@ class SAM(Transform):

         objective.closure = sam_closure

+    def apply_states(self, objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Objective:
+        return objective
 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
     """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
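The new apply_states override is a deliberate no-op: SAM does its work in the update phase by swapping in sam_closure, so returning the objective unchanged tells the Transform machinery there is no separate tensor-level update to apply. A hedged usage sketch (tz.Optimizer, tz.m.SAM and tz.m.LR appear elsewhere in this diff; tz.m.Adam and the closure convention are assumptions):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Optimizer(model.parameters(), tz.m.SAM(), tz.m.Adam(), tz.m.LR(1e-2))

    def closure(backward=True):
        loss = model(torch.randn(4, 10)).square().mean()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)  # SAM needs the closure to re-evaluate at the perturbed point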
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py
@@ -31,7 +31,7 @@ def update_shampoo_preconditioner_(
     if reg != 0:
         accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)

-    if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+    if matrix_power is None: matrix_power = -1 / max(grad.ndim * 2, 2)
     set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))

 def apply_shampoo_preconditioner(
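This reads as a fix to the default Shampoo exponent. In the Shampoo paper (Gupta et al., 2018) an order-k gradient tensor is preconditioned with one Kronecker factor per mode, each raised to the power -1/(2k); for a matrix gradient that is -1/4 per factor, which the new max(grad.ndim * 2, 2) yields, while the old expression gave -1/2. A quick comparison of the two defaults (plain Python):

    # default per-mode exponent for an order-k tensor should be -1/(2k)
    for ndim in (1, 2, 3):
        old = -1 / max(ndim, 2)        # 1: -0.5,  2: -0.5,   3: -0.333...
        new = -1 / max(ndim * 2, 2)    # 1: -0.5,  2: -0.25,  3: -0.1666...
        print(ndim, old, new)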
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py
@@ -51,6 +51,7 @@ def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
     return tensor

 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+# this is only used once per accumulator to initialize it
 @torch.no_grad
 def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
     """
@@ -64,7 +65,19 @@ def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
             final.append(None)
             continue

-
+        if not torch.isfinite(M).all():
+            raise RuntimeError(f"Initial gradient for parameter {M.shape} has non-finite values.")
+
+        M_f64 = M.to(torch.float64) + 1e-30 * torch.eye(M.shape[0], device=M.device, dtype=torch.float64)
+        try:
+            _, Q_f64 = torch_linalg.eigh(M_f64)
+        except RuntimeError as e:
+            if M_f64.is_cpu: raise e
+            M_f64 = M_f64.cpu()
+            _, Q_f64 = torch_linalg.eigh(M_f64) # apparently there is a bug in CUDA eigh
+            Q_f64 = Q_f64.to(M.device)
+
+        Q = Q_f64.to(M.dtype)

         Q = torch.flip(Q, [1])
         final.append(Q)
@@ -156,7 +169,7 @@ class SOAP(TensorTransform):
         beta2: float = 0.95,
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
-        merge_small: bool =
+        merge_small: bool = False,
         max_dim: int = 4096,
         precondition_1d: bool = True,
         eps: float = 1e-8,
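The SOAP initialization path is hardened: non-finite initial gradients fail fast with a clear error, the one-time eigendecomposition runs in float64 with a 1e-30 diagonal jitter that keeps the accumulator strictly positive definite, and a CPU retry works around the CUDA eigh failure the in-code comment mentions. Note also that merge_small now defaults to False; the updated tests above opt back in with merge_small=True. A standalone sketch of the jitter idea (assumed rationale, not the module's code):

    import torch

    M = torch.zeros(4, 4)  # e.g. an accumulator built from a zero gradient
    M_f64 = M.to(torch.float64) + 1e-30 * torch.eye(4, dtype=torch.float64)
    L, Q = torch.linalg.eigh(M_f64)
    # Q is a valid orthogonal basis even for a degenerate accumulator
    assert torch.allclose(Q @ Q.T, torch.eye(4, dtype=torch.float64))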
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py
@@ -154,8 +154,7 @@ class Online(Module):
         closure = objective.closure
         if closure is None: raise ValueError("Closure must be passed for Online")

-        step = self.
-        self.global_state['step'] = step
+        step = self.increment_counter("step", start = 0)

         params = TensorList(objective.params)
         p_cur = params.clone()
@@ -165,7 +164,7 @@ class Online(Module):
         var_c = objective.clone(clone_updates=False)

         # on 1st step just step and store previous params
-        if step ==
+        if step == 0:
             p_prev.copy_(params)

             module.update(var_c)
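The manual bookkeeping through global_state is replaced by a single increment_counter call, and the first-step check changes to step == 0 to match. A sketch of the counter semantics the new code appears to rely on (inferred from the call site, not the actual Module implementation):

    def increment_counter(state: dict, key: str, start: int = 0) -> int:
        # return the current count (starting at `start`), then post-increment it
        value = state.get(key, start)
        state[key] = value + 1
        return value

    state = {}
    assert increment_counter(state, "step") == 0  # first call: the step == 0 branch runs
    assert increment_counter(state, "step") == 1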
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/split.py
@@ -53,11 +53,11 @@ _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.T
 Filter = _SingleFilter | Iterable[_SingleFilter]

 def _make_filter(filter: Filter):
-    if callable(filter): return filter
     if isinstance(filter, torch.Tensor):
         return lambda x: x is filter
     if isinstance(filter, torch.nn.Module):
         return _make_filter(filter.parameters())
+    if callable(filter): return filter

     # iterable
     filters = [_make_filter(f) for f in filter]
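The reorder fixes a real dispatch bug: torch.nn.Module instances are callable (they define __call__ for forward), so with the callable check first, a module passed as a filter was returned as a predicate instead of being expanded into its parameters. A quick demonstration:

    import torch

    m = torch.nn.Linear(2, 2)
    print(callable(m))                     # True - modules are callable
    print(isinstance(m, torch.nn.Module))  # True - so this branch must be tried first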
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/newton.py
@@ -44,7 +44,14 @@ def _newton_update_state_(

     # if any args require eigendecomp, we don't need H or H_inv, we store factors
     if any(i is not None for i in [eigval_fn, eigv_tol, truncate]):
-
+        try:
+            state.pop("H", None)
+            L, Q = torch_linalg.eigh(H, retry_float64=True)
+        except torch.linalg.LinAlgError:
+            state.pop("L",None); state.pop("Q",None)
+            state["H"] = H
+            return
+
         if eigval_fn is not None: L = eigval_fn(L)
         L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eigv_tol)
         state["L"] = L
{torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/__init__.py
@@ -1,2 +1,2 @@
 from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
-from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
+from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD, BoldDriver