torchzero 0.4.2__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.4.2 → torchzero-0.4.4}/PKG-INFO +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/pyproject.toml +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_identical.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_opts.py +2 -2
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/methods.py +37 -32
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/module.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/benchmark.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/eigh.py +2 -2
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/linear_operator.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/matrix_power.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/orthogonalize.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/qr.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/solve.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/svd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adagrad.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adahessian.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adam.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adan.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/aegd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/esgd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/ggt.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/lion.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/mars.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/msam.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/muon.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/natural_gradient.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/orthograd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/rmsprop.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/rprop.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py +6 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py +15 -2
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/sophia_h.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/ggt_basis.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/soap_basis.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/cg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/coordinate_momentum.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/cubic_adam.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/dct.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/fft.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/higher_order_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/l_infinity.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/newtonnewton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/structural_projections.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/spsa1.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/least_squares/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/least_squares/gn.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/_polyinterp.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/adaptive.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/debug.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/escape.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/gradient_accumulation.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/misc.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py +2 -3
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/regularization.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/split.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/switch.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/momentum.py +9 -9
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/higher_level.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/opt_utils.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/cast.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lsr1.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/inm.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/multipoint.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/newton.py +8 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/newton_cg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/nystrom.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/sampling.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/__init__.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/adaptive.py +42 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/lr.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/termination/termination.py +2 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/dogleg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/levenberg_marquardt.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_cg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_region.py +0 -0
- torchzero-0.4.4/torchzero/modules/weight_decay/__init__.py +8 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/weight_decay/weight_decay.py +84 -7
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/root.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/directsearch.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/fcmaes.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/mads.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/optuna.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/compile.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/optuna_tools.py +1 -1
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/params.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/PKG-INFO +1 -1
- torchzero-0.4.2/torchzero/modules/weight_decay/__init__.py +0 -2
- {torchzero-0.4.2 → torchzero-0.4.4}/setup.cfg +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_module.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_module_autograd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_objective.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_tensorlist.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/minimize.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/chain.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/functional.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/modular.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/objective.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/reformulation.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/transform.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/linalg_utils.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/sketch.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/lre_optimizers.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/_psgd_utils.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/matrix_nag.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/interpolation.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/homotopy.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/damping.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/sg2.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/restarts/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/restarts/restars.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/ifn.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/rsn.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/termination/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/variance_reduction/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/variance_reduction/svrg.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/weight_decay/reinit.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/zeroth_order/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/zeroth_order/cd.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/mbs.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/moors.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/pybobyqa.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/brute.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/direct.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/wrapper.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/benchmarks/__init__.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/benchmarks/logistic.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/metrics.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/thoad_tools.py +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/SOURCES.txt +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.4.2 → torchzero-0.4.4}/tests/test_identical.py

@@ -105,7 +105,7 @@ def test_adam(amsgrad):
     tz_fn_ops = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
-            tz.m.EMA(0.9,
+            tz.m.EMA(0.9, debias=True),
            [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
        ))
    tz_fn_ops2 = lambda p: tz.Optimizer(
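For context, the test assembles Adam out of primitive modules. A minimal sketch of the same composition outside the test harness, using the new debias flag (the model and the tz.m.LR step size are added here for completeness; module names are taken from the diff):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # Adam built from primitives: a debiased EMA of gradients divided by
    # a debiased EMA of squared gradients (plus epsilon), as in the test above
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.DivModules(
            tz.m.EMA(0.9, debias=True),
            [tz.m.SqrtEMASquared(0.999, debiased=True), tz.m.Add(1e-8)],
        ),
        tz.m.LR(1e-3),
    )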
{torchzero-0.4.2 → torchzero-0.4.4}/tests/test_opts.py

@@ -727,8 +727,8 @@ Adam = Run(
 )
 # ------------------------------ optimizers/soap ----------------------------- #
 SOAP = Run(
-    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
-    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
+    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(merge_small=True), tz.m.LR(0.4)),
+    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1, merge_small=True), tz.m.LR(1)),
     needs_closure=False,
     # merge and unmerge lrs are very different so need to test convergence separately somewhere
     func='rosen', steps=50, loss=4, merge_invariant=False,
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/methods.py

@@ -14,82 +14,87 @@ from ..utils import tofloat


 def _get_method_from_str(method: str) -> list[Module]:
-
+    stripped = ''.join(c for c in method.lower().strip() if c.isalnum())

-    if
+    if stripped == "bfgs":
         return [m.RestartOnStuck(m.BFGS()), m.Backtracking()]

-    if
+    if stripped == "lbfgs":
         return [m.LBFGS(100), m.Backtracking()]

-    if
+    if stripped == "newton":
         return [m.Newton(), m.Backtracking()]

-    if
+    if stripped == "sfn":
         return [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()]

-    if
+    if stripped == "inm":
         return [m.ImprovedNewton(), m.Backtracking()]

-    if
+    if stripped == 'crn':
         return [m.CubicRegularization(m.Newton())]

-    if
+    if stripped == "commondirections":
         return [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()]

-    if
+    if stripped == "trust":
         return [m.LevenbergMarquardt(m.Newton())]

-    if
-        return [m.TrustCG(m.Newton())]
-
-    if method == "dogleg":
+    if stripped == "dogleg":
         return [m.Dogleg(m.Newton())]

-    if
-        return [m.LevenbergMarquardt(m.BFGS())]
+    if stripped == "trustbfgs":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.BFGS()))]

-    if
-        return [m.LevenbergMarquardt(m.SR1())]
+    if stripped == "trustsr1":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.SR1()))]

-    if
+    if stripped == "newtoncg":
         return [m.NewtonCG(), m.Backtracking()]

-    if
+    if stripped == "tn":
         return [m.NewtonCG(maxiter=10), m.Backtracking()]

-    if
+    if stripped == "trustncg":
         return [m.NewtonCGSteihaug()]

-    if
+    if stripped == "gd":
         return [m.Backtracking()]

-    if
+    if stripped == "cg":
         return [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)]

-    if
+    if stripped in ("shor", "shorr"):
+        return [m.ShorR(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "pgm":
+        return [m.ProjectedGradientMethod(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "bb":
         return [m.RestartOnStuck(m.BarzilaiBorwein())]

-    if
+    if stripped == "bbstab":
         return [m.BBStab()]

-    if
+    if stripped == "adgd":
         return [m.AdGD()]

-    if
+    if stripped in ("bd", "bolddriver"):
+        return [m.BoldDriver()]
+
+    if stripped in ("gn", "gaussnewton"):
         return [m.GaussNewton(), m.Backtracking()]

-    if
+    if stripped == "rprop":
         return [m.Rprop(alpha=1e-3)]

-    if
+    if stripped == "lm":
         return [m.LevenbergMarquardt(m.GaussNewton())]

-    if
+    if stripped == "mlm":
         return [m.LevenbergMarquardt(m.GaussNewton(), y=1)]

-    if
+    if stripped == "cd":
         return [m.CD(), m.ScipyMinimizeScalar(maxiter=8)]

-
-    raise NotImplementedError(method)
+    raise NotImplementedError(stripped)
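The rewrite also normalizes the requested method name before dispatch: it lowercases the string and drops every non-alphanumeric character, so spellings such as "Gauss-Newton", "gauss newton" and "gaussnewton" all resolve to the same branch. A standalone illustration of that normalization (not the library function itself):

    def normalize(method: str) -> str:
        # mirrors the stripping logic in _get_method_from_str
        return ''.join(c for c in method.lower().strip() if c.isalnum())

    assert normalize("Gauss-Newton") == "gaussnewton"
    assert normalize("  L-BFGS ") == "lbfgs"
    assert normalize("bold driver") == "bolddriver"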
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/eigh.py

@@ -285,8 +285,8 @@ def rank1_eigh(v: torch.Tensor):
     vv = v.dot(v)
     norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)

-    L = vv.unsqueeze(0) # (
-    Q = v.unsqueeze(-1) / norm # (m,
+    L = vv.unsqueeze(0) # (1, )
+    Q = v.unsqueeze(-1) / norm # (m, 1)

     return L, Q
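The corrected comments document the shapes: for the rank-1 matrix v vᵀ, the single nonzero eigenvalue is v·v = ‖v‖², stored in L with shape (1,), and the corresponding unit eigenvector v/‖v‖ is stored in Q with shape (m, 1). A quick numerical check of that identity:

    import torch

    v = torch.randn(5, dtype=torch.float64)
    L = v.dot(v).unsqueeze(0)        # (1,)   eigenvalue ||v||^2
    Q = v.unsqueeze(-1) / v.norm()   # (m, 1) unit eigenvector v / ||v||

    # Q @ diag(L) @ Q^T reconstructs the rank-1 matrix v v^T
    assert torch.allclose((Q * L) @ Q.T, torch.outer(v, v))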
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py

@@ -46,7 +46,7 @@ def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Ten
     try:
         return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable

-    except torch.linalg.LinAlgError as e:
+    except (torch.linalg.LinAlgError, RuntimeError) as e:
         if not retry_float64: raise e
         dtype = A.dtype
         if dtype == torch.float64: raise e
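On CUDA, torch.linalg.eigh failures often surface as a plain RuntimeError rather than torch.linalg.LinAlgError, so catching only the latter let those errors bypass the retry_float64 path. A sketch of the retry pattern in isolation (an approximation of the wrapper's behavior, not its exact code):

    import torch

    def eigh_with_retry(A: torch.Tensor):
        try:
            return torch.linalg.eigh(A)
        except (torch.linalg.LinAlgError, RuntimeError):
            if A.dtype == torch.float64:
                raise
            # retry the decomposition in higher precision, then cast back
            L, Q = torch.linalg.eigh(A.to(torch.float64))
            return L.to(A.dtype), Q.to(A.dtype)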
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py

@@ -1,7 +1,10 @@
+from collections.abc import Mapping, Sequence
 from contextlib import nullcontext
+from typing import Any
 import torch
+
 from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
-from ...core import Transform
+from ...core import Transform, Objective


 class SAM(Transform):
@@ -126,6 +129,8 @@ class SAM(Transform):

         objective.closure = sam_closure

+    def apply_states(self, objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Objective:
+        return objective
 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
     """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py

@@ -31,7 +31,7 @@ def update_shampoo_preconditioner_(
     if reg != 0:
         accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)

-    if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+    if matrix_power is None: matrix_power = -1 / max(grad.ndim * 2, 2)
     set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))

 def apply_shampoo_preconditioner(
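The new default exponent follows the Shampoo convention of raising each Kronecker factor to the power −1/(2k) for an order-k tensor, so a matrix gradient (k = 2) is now preconditioned with G^(−1/4) per side instead of G^(−1/2); the vector case is unchanged because of the max(..., 2) floor. A tiny check of the arithmetic:

    def default_matrix_power(ndim: int) -> float:
        # new default: -1/(2k) per Kronecker factor for an order-k tensor
        return -1 / max(ndim * 2, 2)

    assert default_matrix_power(1) == -0.5     # vectors: unchanged
    assert default_matrix_power(2) == -0.25    # matrices: was -0.5 before the fix
    assert default_matrix_power(3) == -1 / 6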
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py

@@ -51,6 +51,7 @@ def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
     return tensor

 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+# this is only used once per accumulator to initialize it
 @torch.no_grad
 def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
     """
@@ -64,7 +65,19 @@ def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
             final.append(None)
             continue

-
+        if not torch.isfinite(M).all():
+            raise RuntimeError(f"Initial gradient for parameter {M.shape} has non-finite values.")
+
+        M_f64 = M.to(torch.float64) + 1e-30 * torch.eye(M.shape[0], device=M.device, dtype=torch.float64)
+        try:
+            _, Q_f64 = torch_linalg.eigh(M_f64)
+        except RuntimeError as e:
+            if M_f64.is_cpu: raise e
+            M_f64 = M_f64.cpu()
+            _, Q_f64 = torch_linalg.eigh(M_f64) # apparently there is a bug in CUDA eigh
+            Q_f64 = Q_f64.to(M.device)
+
+        Q = Q_f64.to(M.dtype)

         Q = torch.flip(Q, [1])
         final.append(Q)
@@ -156,7 +169,7 @@ class SOAP(TensorTransform):
         beta2: float = 0.95,
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
-        merge_small: bool =
+        merge_small: bool = False,
         max_dim: int = 4096,
         precondition_1d: bool = True,
         eps: float = 1e-8,
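Two behavioral changes here: merge_small now defaults to False (the tests above opt in explicitly), and the one-time eigendecomposition that seeds the preconditioner basis is hardened: it rejects non-finite inputs, works in float64 with a tiny diagonal jitter, and retries on the CPU when the CUDA solver raises. The fallback pattern in isolation (a sketch using torch.linalg.eigh directly, not the library's torch_linalg wrapper):

    import torch

    def robust_eigvecs(M: torch.Tensor) -> torch.Tensor:
        # float64 plus a small diagonal jitter for numerical stability
        M64 = M.to(torch.float64) + 1e-30 * torch.eye(M.shape[0], device=M.device, dtype=torch.float64)
        try:
            _, Q = torch.linalg.eigh(M64)
        except RuntimeError:
            if M64.is_cpu:
                raise
            # CUDA eigh can fail where the CPU path succeeds; retry on CPU
            _, Q = torch.linalg.eigh(M64.cpu())
            Q = Q.to(M.device)
        return Q.to(M.dtype)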
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/ggt_basis.py

@@ -111,7 +111,7 @@ class GGTBasis(TensorTransform):
         inner: Chainable | None = None,
     ):
         defaults = locals().copy()
-        del defaults['self'], defaults['inner']
+        del defaults['self'], defaults['inner'], defaults["basis_opt"]

         super().__init__(defaults, concat_params=True, inner=inner)
         self.set_child("basis_opt", basis_opt)
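The constructor builds its hyperparameter dict from locals().copy(), so arguments that are modules rather than plain settings must be removed by hand; basis_opt is a child registered via set_child, and leaving it in defaults would treat it as a per-parameter hyperparameter. The pattern in miniature (illustrative names only):

    class Example:
        def __init__(self, lr: float = 0.1, beta: float = 0.9, child=None, inner=None):
            defaults = locals().copy()
            # keep only real hyperparameters: drop self and module-valued arguments
            del defaults['self'], defaults['inner'], defaults['child']
            assert defaults == {'lr': 0.1, 'beta': 0.9}

    Example()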
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py

@@ -154,8 +154,7 @@ class Online(Module):
         closure = objective.closure
         if closure is None: raise ValueError("Closure must be passed for Online")

-        step = self.
-        self.global_state['step'] = step
+        step = self.increment_counter("step", start = 0)

         params = TensorList(objective.params)
         p_cur = params.clone()
@@ -165,7 +164,7 @@ class Online(Module):
         var_c = objective.clone(clone_updates=False)

         # on 1st step just step and store previous params
-        if step ==
+        if step == 0:
             p_prev.copy_(params)

             module.update(var_c)
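The manual two-line bookkeeping is folded into a counter helper. Judging from the use site (start=0, and the first-step branch now comparing against 0), increment_counter appears to return the current value and advance the stored one; a plausible sketch of such a helper, which is an assumption about the Module API rather than its actual code:

    def increment_counter(self, key: str, start: int = 0) -> int:
        # return the current counter value, then advance it for the next call
        value = self.global_state.get(key, start)
        self.global_state[key] = value + 1
        return value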
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/split.py

@@ -53,11 +53,11 @@ _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.T
 Filter = _SingleFilter | Iterable[_SingleFilter]

 def _make_filter(filter: Filter):
-    if callable(filter): return filter
     if isinstance(filter, torch.Tensor):
         return lambda x: x is filter
     if isinstance(filter, torch.nn.Module):
         return _make_filter(filter.parameters())
+    if callable(filter): return filter

     # iterable
     filters = [_make_filter(f) for f in filter]
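The reordering matters because torch.nn.Module instances are callable: with the callable check first, a module passed as a filter was captured as a predicate function instead of being expanded into its parameters. Moving the Tensor and Module cases ahead of the generic callable case fixes the dispatch. A quick demonstration of the ambiguity:

    import torch

    net = torch.nn.Linear(2, 2)

    # a Module is callable, so a callable-first dispatch would treat it
    # as a predicate rather than as "match this module's parameters"
    assert callable(net)
    assert isinstance(net, torch.nn.Module)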
{torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/momentum.py

@@ -6,7 +6,7 @@ import torch

 from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..opt_utils import debias, ema_
+from ..opt_utils import debias as _debias, ema_


 class EMA(TensorTransform):
@@ -15,13 +15,13 @@ class EMA(TensorTransform):
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
-
+        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0,
-        defaults = dict(momentum=momentum,dampening=dampening,
+    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
+        defaults = dict(momentum=momentum,dampening=dampening,debias=debias,lerp=lerp,ema_init=ema_init)
         super().__init__(defaults, uses_grad=False)

         self.add_projected_keys("grad", "exp_avg")
@@ -30,7 +30,7 @@ class EMA(TensorTransform):
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-
+        debias, lerp, ema_init = itemgetter('debias','lerp','ema_init')(settings[0])

         exp_avg = unpack_states(states, tensors, 'exp_avg',
             init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
@@ -38,7 +38,7 @@ class EMA(TensorTransform):

         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

-        if
+        if debias: return _debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
         else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned


@@ -49,14 +49,14 @@ class HeavyBall(EMA):
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
-
+        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional):
             whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0,
-        super().__init__(momentum=momentum, dampening=dampening,
+    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+        super().__init__(momentum=momentum, dampening=dampening, debias=debias, lerp=lerp, ema_init=ema_init)

 def nag_(
     tensors_: TensorList,
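The debias flag applies Adam-style bias correction: a zero-initialized EMA with coefficient β has only accumulated a fraction 1 − βᵗ of a constant signal after t steps, so dividing by that factor restores the scale. Numerically, assuming the standard correction that the _debias helper implements:

    beta, steps = 0.9, 3
    ema = 0.0
    for t in range(1, steps + 1):
        ema = beta * ema + (1 - beta) * 1.0    # EMA of the constant signal 1.0
        corrected = ema / (1 - beta ** t)      # bias-corrected estimate
        assert abs(corrected - 1.0) < 1e-12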