torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
tests/test_identical.py,sha256=Y48_1f5WrltmO8a_-x-9Yltz2ZeMh8N8q3MGjOCkJhA,11552
|
|
2
|
+
tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
|
|
3
|
+
tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
|
|
4
|
+
tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
|
|
5
|
+
tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
|
|
6
|
+
tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
|
|
7
|
+
tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
|
|
8
|
+
torchzero/__init__.py,sha256=nit4KxrRoW6hJDGOy0jkphuawY5gAvPqrYY11Yct6fA,133
|
|
9
|
+
torchzero/core/__init__.py,sha256=h9Ck7XX2XuJUTojU2IMa_2TprXZHbgo748txa3z7-2o,341
|
|
10
|
+
torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
|
|
11
|
+
torchzero/core/functional.py,sha256=TSygtyQHDhqf998--hF48yIFr-y3Ycz8arjjR8x1ILU,3156
|
|
12
|
+
torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
|
|
13
|
+
torchzero/core/module.py,sha256=HfbPfxXxgyBf9wQl7Fpw6B6Ux6UYfvPEmITC64ozb_Q,18012
|
|
14
|
+
torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
|
|
15
|
+
torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
|
|
16
|
+
torchzero/core/transform.py,sha256=aJRBtvYjKqD-Ic_AkzeSINYDsTaBAErA-kocEl3PHZw,12244
|
|
17
|
+
torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
|
|
18
|
+
torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
|
|
19
|
+
torchzero/linalg/eigh.py,sha256=YC8x5NEWWsnc3suCebnTfeb4lVMhy-H8LGOZbGnwd8A,7902
|
|
20
|
+
torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
|
|
21
|
+
torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
|
|
22
|
+
torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
|
|
23
|
+
torchzero/linalg/orthogonalize.py,sha256=Fv6zv1JvS9AVwjiMVed55J8-pEbVZv7vqoEo5g0Zrv0,3270
|
|
24
|
+
torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
|
|
25
|
+
torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
|
|
26
|
+
torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
|
|
27
|
+
torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
|
|
28
|
+
torchzero/modules/__init__.py,sha256=dsOalCw-OVkD8rhpQdcODc3Hsd_sQ2_2xVC-J8mlSuk,632
|
|
29
|
+
torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
|
|
30
|
+
torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
|
|
31
|
+
torchzero/modules/adaptive/adagrad.py,sha256=hMT-Al-vtD6tzPUpQ79LCNko97D7rJN5ji9JOfBqR3k,12015
|
|
32
|
+
torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
|
|
33
|
+
torchzero/modules/adaptive/adam.py,sha256=Okm7Sc9fMArQAZ7Ph4Etq68uL-IXKY4YNqHWpTzPoTY,3767
|
|
34
|
+
torchzero/modules/adaptive/adan.py,sha256=965tBUwKy6uDiY2la6fVcGcsvGMs90Zg-ZHPtozJGe4,4110
|
|
35
|
+
torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
|
|
36
|
+
torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
|
|
37
|
+
torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
|
|
38
|
+
torchzero/modules/adaptive/ggt.py,sha256=eYCeV3GArdLv9WuWeim0V3CHJYl3FVKtrtsGshkqwWg,6608
|
|
39
|
+
torchzero/modules/adaptive/lion.py,sha256=H3aI2qnrMtmkvXcoddzjjxdkoD5cq_QwIkLmd_bVPso,1085
|
|
40
|
+
torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
|
|
41
|
+
torchzero/modules/adaptive/mars.py,sha256=w-cK-1tFuR74SY01xS5jsg1b9qs3l8eOptGrUyQ2m80,2261
|
|
42
|
+
torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
|
|
43
|
+
torchzero/modules/adaptive/msam.py,sha256=nqwjuhBMX2UO-omUIeOcD5ti6PIKfKs-RVCn7ourkKA,6946
|
|
44
|
+
torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
|
|
45
|
+
torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
|
|
46
|
+
torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
|
|
47
|
+
torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
|
|
48
|
+
torchzero/modules/adaptive/rprop.py,sha256=a4_UkWse5u2JFAEIlxQqDBUwvUfxh1kNs2ZIhtccnWE,11540
|
|
49
|
+
torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
|
|
50
|
+
torchzero/modules/adaptive/shampoo.py,sha256=C_Mo7UFQtDxW4McWJjT731FNAp3g9MqF0Hka54Yi3xQ,9847
|
|
51
|
+
torchzero/modules/adaptive/soap.py,sha256=hz2N6-jUSWU93RNViIS1c-Ue2uKmQx6BxyYg6mEa2fo,12408
|
|
52
|
+
torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
|
|
53
|
+
torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
|
|
54
|
+
torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
|
|
55
|
+
torchzero/modules/adaptive/psgd/psgd.py,sha256=CVDJI3fcdPCpLCWFY_pvd_COaZMB3nrYOYDdqQcEaOY,73340
|
|
56
|
+
torchzero/modules/adaptive/psgd/psgd_dense_newton.py,sha256=62R9rkWRynrvmTzb12zi5oQoWWdAWz4i1EC4FXa4zv4,7094
|
|
57
|
+
torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2GG0nfjx6NWzUHYKY_FAvRpg,8381
|
|
58
|
+
torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
|
|
59
|
+
torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
|
|
60
|
+
torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
|
|
61
|
+
torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
|
|
62
|
+
torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
|
|
63
|
+
torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
|
|
64
|
+
torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
|
|
65
|
+
torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
|
|
66
|
+
torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
|
|
67
|
+
torchzero/modules/experimental/__init__.py,sha256=YbBrWu2vkXHiBcDXmim-Yte4ZxfmQCs_0fCeIArvtnM,942
|
|
68
|
+
torchzero/modules/experimental/adanystrom.py,sha256=fUWPxxi1aJhWme_d31dBG0XxEZY1hJr6AEiFHdFDxCQ,8970
|
|
69
|
+
torchzero/modules/experimental/common_directions_whiten.py,sha256=R_1fQKlvMD99oFrflJLgxl6ObV8jyPc7-NxAUFQeoYA,4941
|
|
70
|
+
torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
|
|
71
|
+
torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
|
|
72
|
+
torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
|
|
73
|
+
torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
|
|
74
|
+
torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
|
|
75
|
+
torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
|
|
76
|
+
torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
|
|
77
|
+
torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
|
|
78
|
+
torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
|
|
79
|
+
torchzero/modules/experimental/l_infinity.py,sha256=zu3aRLmZkU4LxfyToqjU8re0BLMUd6gl16_9AffIfcw,4752
|
|
80
|
+
torchzero/modules/experimental/matrix_nag.py,sha256=fjj07uZbmYvy4AgrtqvX_WLwvK7HBpKx09Q3BK3L0jo,4451
|
|
81
|
+
torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNgEnxOr04PVBmIiBxmiQ,4073
|
|
82
|
+
torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
|
|
83
|
+
torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
|
|
84
|
+
torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
|
|
85
|
+
torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
|
|
86
|
+
torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
|
|
87
|
+
torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
|
|
88
|
+
torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
|
|
89
|
+
torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
|
|
90
|
+
torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
|
|
91
|
+
torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
|
|
92
|
+
torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
|
|
93
|
+
torchzero/modules/least_squares/gn.py,sha256=3RQ_7e35Ql9uVUUPi34nef9eQNeZ09fldi964V61Tgg,7889
|
|
94
|
+
torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
|
|
95
|
+
torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
|
|
96
|
+
torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
|
|
97
|
+
torchzero/modules/line_search/backtracking.py,sha256=xomhfRmOquGn_liGLS-5PKP89YpAtzFbM3vPTjz0T5A,9051
|
|
98
|
+
torchzero/modules/line_search/interpolation.py,sha256=tHXlZD1MgfYaymhXY75k9CocltwSYDYZg6ENDCEUiss,4942
|
|
99
|
+
torchzero/modules/line_search/line_search.py,sha256=Qcil-WbXnOpdi_DqzvRzXnInedUYNXDglsnYqfGYHdk,12925
|
|
100
|
+
torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
|
|
101
|
+
torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4JMx2aVHqFn_8qpX6g,14979
|
|
102
|
+
torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
|
|
103
|
+
torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
|
|
104
|
+
torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
|
|
105
|
+
torchzero/modules/misc/gradient_accumulation.py,sha256=Xzjt_ulm6Z3mpmtagoUqoefhoeSDVnmX__tVbcI_RQE,2271
|
|
106
|
+
torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
|
|
107
|
+
torchzero/modules/misc/misc.py,sha256=f-3qxBq1KYI3iGYJXzv1cHEJHc0ScEp-vCLCgiaEgJQ,15002
|
|
108
|
+
torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
|
|
109
|
+
torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
|
|
110
|
+
torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
|
|
111
|
+
torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
|
|
112
|
+
torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
|
|
113
|
+
torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
|
|
114
|
+
torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
|
|
115
|
+
torchzero/modules/momentum/momentum.py,sha256=MPHd4TU1bSlEKLGfueNdmaZ13V5J1suW6agBc3SvrTs,4389
|
|
116
|
+
torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
|
|
117
|
+
torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
|
|
118
|
+
torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
|
|
119
|
+
torchzero/modules/ops/higher_level.py,sha256=cUh-908S0GWVGekmUN5c_Vx0HP3P2tQoKN3COQM5TaQ,8965
|
|
120
|
+
torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
|
|
121
|
+
torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
|
|
122
|
+
torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
|
|
123
|
+
torchzero/modules/ops/utility.py,sha256=UkR3BfN_NBsZe78oERnknf6lmgeGNiEbtfUDNgU0YoQ,4423
|
|
124
|
+
torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
|
|
125
|
+
torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
|
|
126
|
+
torchzero/modules/projections/galore.py,sha256=70k30-2RJ3ncTNZpjxBhyYq1yJVFGfS2YUvU2trIK4o,258
|
|
127
|
+
torchzero/modules/projections/projection.py,sha256=FK1gURrOdSaKLVsTHgSeokoTfCGTLYGY9WB4slpclr4,14485
|
|
128
|
+
torchzero/modules/quasi_newton/__init__.py,sha256=o0N_XvzY5qGOsRYXU2bNdF8_zRb5ZZxaL6inTBl2Sk4,536
|
|
129
|
+
torchzero/modules/quasi_newton/damping.py,sha256=bz2GFjasb-2B6bo8dn3-UJ3ddkiUbiJtaQBLqsixv_E,2806
|
|
130
|
+
torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=q6qCYQzWZxqg5S55mmI6Kf0uRoIr_dtSWrFd-2jpe6o,6929
|
|
131
|
+
torchzero/modules/quasi_newton/lbfgs.py,sha256=ftuY65N7_zJtByQ_JM_sRgxKqW7vNktuMb7qFqDGF8M,11203
|
|
132
|
+
torchzero/modules/quasi_newton/lsr1.py,sha256=5gZKoFBVCLg3XeXOQTALAkk3XvU2wMdjDi5J_B7iJLs,8515
|
|
133
|
+
torchzero/modules/quasi_newton/quasi_newton.py,sha256=vjgYWInavb81twwQQWPfK7gj1aT4ljJNEMSnNsDOM9c,46015
|
|
134
|
+
torchzero/modules/quasi_newton/sg2.py,sha256=-d6dVqlt6Xf03tN1a26Jx9YK2kqOS_snuMuIUT3iYqQ,4164
|
|
135
|
+
torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
|
|
136
|
+
torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
|
|
137
|
+
torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
|
|
138
|
+
torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
|
|
139
|
+
torchzero/modules/second_order/inm.py,sha256=OddoZHQfSuFnlx_7Zj2qiVcC2A_9yMVn_0Gy1A7hNAg,3420
|
|
140
|
+
torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
|
|
141
|
+
torchzero/modules/second_order/newton.py,sha256=QcLXsglvf4zJEwR4cldsGVZCABQtxb6U5qVmU3spN_A,11061
|
|
142
|
+
torchzero/modules/second_order/newton_cg.py,sha256=k8G8CSmeIQZObkWVURFnbF_4g2UvJiwh3xToxn7sFJE,14816
|
|
143
|
+
torchzero/modules/second_order/nystrom.py,sha256=WQFfJj0DOfWXyyx36C54m0WqZPIvTTK7n8U7khLhGLg,13359
|
|
144
|
+
torchzero/modules/second_order/rsn.py,sha256=9s-JyJNNeDlIFv8YVGn7y8DGPnP93WJEjpUQXehX3uY,9980
|
|
145
|
+
torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
|
|
146
|
+
torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
|
|
147
|
+
torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
|
|
148
|
+
torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
|
|
149
|
+
torchzero/modules/step_size/adaptive.py,sha256=TYLESwhcs6QCR9AvQWtMuU-XJRDmNEmSCnPdBFwijiY,14679
|
|
150
|
+
torchzero/modules/step_size/lr.py,sha256=Pi8B5hhJWq_NdVMdR4reZhqu1jjC2S5TWTzkMKAEMcM,5931
|
|
151
|
+
torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
|
|
152
|
+
torchzero/modules/termination/termination.py,sha256=lJXLmtA84JoK_QHhjBxfW9lkxGIxqg7cE7d755MeduE,6905
|
|
153
|
+
torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
|
|
154
|
+
torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdyMvUzdVqOMPyKnJ7FC_d5cR5V8,6711
|
|
155
|
+
torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
|
|
156
|
+
torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
|
|
157
|
+
torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
|
|
158
|
+
torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
|
|
159
|
+
torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
|
|
160
|
+
torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
|
|
161
|
+
torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
|
|
162
|
+
torchzero/modules/weight_decay/reinit.py,sha256=Ngb3AUfJU-aE-lUSSDmf6rlx2zJUgJ6JrQ3U1DpuEh8,3311
|
|
163
|
+
torchzero/modules/weight_decay/weight_decay.py,sha256=dC0JidMlUU4WzXRTSwQT0Gp45UdKvZLu6JnXveJTR4c,5362
|
|
164
|
+
torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
|
|
165
|
+
torchzero/modules/wrappers/optim_wrapper.py,sha256=T-DnI_pbNpt1BN9qvW6KOVGly1RixsdSpIC8FtmmDlY,4701
|
|
166
|
+
torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
|
|
167
|
+
torchzero/modules/zeroth_order/cd.py,sha256=6mHQC_pInVQ6HmhW5yAKwkPjBVwhUg8uX-J3GEIb9wQ,5018
|
|
168
|
+
torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
|
|
169
|
+
torchzero/optim/mbs.py,sha256=mUbl9yl7cxNzmp7J6CXkkjcUd3KBF5zPdDjuiDgJxVU,10970
|
|
170
|
+
torchzero/optim/root.py,sha256=MnXytsgTjDKhYw3UKux1O-g4vuHsFQu2Emsq5zphu_8,2308
|
|
171
|
+
torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
|
|
172
|
+
torchzero/optim/utility/split.py,sha256=uNnKA40OiXUlr-vlHuU_rLEUfXQXgvr6Cd9yGzjWJiA,1702
|
|
173
|
+
torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
174
|
+
torchzero/optim/wrappers/directsearch.py,sha256=7f1Zy9ZxftgIFxJ9_1snyMjp8ypB4ULyddOqgxdwkEg,9806
|
|
175
|
+
torchzero/optim/wrappers/fcmaes.py,sha256=Bobx87TWinICd3SyNaKhoplpNbrLdD5vR1J1TZw69Cs,3437
|
|
176
|
+
torchzero/optim/wrappers/mads.py,sha256=Gm1msCsVNJDq-WS913peZ0TzXU_7wgP_sPmMfEK0VIY,2326
|
|
177
|
+
torchzero/optim/wrappers/moors.py,sha256=AjEeovPEvxEGJwbhOaOCrCIxPgwOHilkj-vWb-88Y_0,2073
|
|
178
|
+
torchzero/optim/wrappers/nevergrad.py,sha256=wFjIPOSffdRkxdtbwBMBw1_ubP9kS2z4nHV6YtqUrvw,4442
|
|
179
|
+
torchzero/optim/wrappers/nlopt.py,sha256=FzU_3VuNN0_ngT_FlQs7ZOL_RZcfeFtJNmklqvyzTkw,8694
|
|
180
|
+
torchzero/optim/wrappers/optuna.py,sha256=Gvs8tGGIV-mJZ1nMrLN3Ju3am-lbabLd_-mVRO6JlkY,2252
|
|
181
|
+
torchzero/optim/wrappers/pybobyqa.py,sha256=1eH3HW2HVyNneZ5vB9ch1hb0OdoYrnk4pJmQVRIOAw4,6309
|
|
182
|
+
torchzero/optim/wrappers/wrapper.py,sha256=eEZnqKRc2cSFVVXNyOGGReMX3MvY0eoHSRjFsQ8pCSw,4380
|
|
183
|
+
torchzero/optim/wrappers/scipy/__init__.py,sha256=lAciW9WTZ8JZ9ZFunyi78cOa6URhTL-rIjk1AdqbrxM,262
|
|
184
|
+
torchzero/optim/wrappers/scipy/basin_hopping.py,sha256=11FuVLYDJxbyxW-44UO-JRef7MErGFfM_PqqqADA8RQ,4435
|
|
185
|
+
torchzero/optim/wrappers/scipy/brute.py,sha256=MD5fWcUo0-COrq5_OKM8X5SSMZaN--4ThW8mdISLX00,1213
|
|
186
|
+
torchzero/optim/wrappers/scipy/differential_evolution.py,sha256=GBJSh0K5ueZX3fY9vxTntQnuIoLDjOQyOj6TMCmTGBg,2669
|
|
187
|
+
torchzero/optim/wrappers/scipy/direct.py,sha256=lTkH8aRY7sWkhrtQ8Z-m8fSdhTpsLxOIQSJprSqThFA,1821
|
|
188
|
+
torchzero/optim/wrappers/scipy/dual_annealing.py,sha256=mIdXy8F3YEfrQCXNxps8T33yRDuqJxaaIjzB0wlKo90,4117
|
|
189
|
+
torchzero/optim/wrappers/scipy/experimental.py,sha256=-_82LtrDVg0qRg6E1coXugsCGGcVNXgxv6m5Anaaw24,4833
|
|
190
|
+
torchzero/optim/wrappers/scipy/minimize.py,sha256=6O2Js2ecGZkNE3mqXJR9Rqp5xlh3KIVYBMF82k5hrws,6491
|
|
191
|
+
torchzero/optim/wrappers/scipy/sgho.py,sha256=mKG9uG-zaEzB6mbiZ25FnXdRcfvGYYOHapuHniyJ8r8,4054
|
|
192
|
+
torchzero/utils/__init__.py,sha256=tpy9ti5Ub5d1zQPFODZ4PjmFKNzoZTd-NByd_snlYtk,761
|
|
193
|
+
torchzero/utils/compile.py,sha256=dY9ioWQvJt71HQN_z4kXX4KgtJ0xOW0MKNlZpHD6118,5130
|
|
194
|
+
torchzero/utils/derivatives.py,sha256=Sc20EH2v2czjH9Z8UChvq0EaYtvOEJKEYOk3fVb0Z6M,18085
|
|
195
|
+
torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
|
|
196
|
+
torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
|
|
197
|
+
torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
|
|
198
|
+
torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
|
|
199
|
+
torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
|
|
200
|
+
torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
|
|
201
|
+
torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
|
|
202
|
+
torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
|
|
203
|
+
torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
|
|
204
|
+
torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
205
|
+
torchzero/utils/benchmarks/logistic.py,sha256=RHsjHEWkPqaag0kt3wfmdddh4DhftcyW9r70tj9OGp4,4382
|
|
206
|
+
torchzero-0.4.1.dist-info/METADATA,sha256=hB0rFqXnaRbwVkFRwTwjXpKnIFLi8MBvLXbgXTuUGWk,564
|
|
207
|
+
torchzero-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
208
|
+
torchzero-0.4.1.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
|
|
209
|
+
torchzero-0.4.1.dist-info/RECORD,,
|
tests/test_vars.py
DELETED
|
@@ -1,185 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
import torch
|
|
3
|
-
from torchzero.core.module import Var
|
|
4
|
-
from torchzero.utils.tensorlist import TensorList
|
|
5
|
-
|
|
6
|
-
@torch.no_grad
|
|
7
|
-
def test_var_get_loss():
|
|
8
|
-
|
|
9
|
-
# ---------------------------- test that it works ---------------------------- #
|
|
10
|
-
params = [torch.tensor(2.0, requires_grad=True)]
|
|
11
|
-
evaluated = False
|
|
12
|
-
|
|
13
|
-
def closure_1(backward=True):
|
|
14
|
-
assert not backward, 'backward = True'
|
|
15
|
-
|
|
16
|
-
# ensure closure only evaluates once
|
|
17
|
-
nonlocal evaluated
|
|
18
|
-
assert evaluated is False, 'closure was evaluated twice'
|
|
19
|
-
evaluated = True
|
|
20
|
-
|
|
21
|
-
loss = params[0]**2
|
|
22
|
-
if backward:
|
|
23
|
-
params[0].grad = None
|
|
24
|
-
loss.backward()
|
|
25
|
-
else:
|
|
26
|
-
assert not loss.requires_grad, "loss requires grad with backward=False"
|
|
27
|
-
return loss
|
|
28
|
-
|
|
29
|
-
var = Var(params=params, closure=closure_1, model=None, current_step=0)
|
|
30
|
-
|
|
31
|
-
assert var.loss is None, var.loss
|
|
32
|
-
|
|
33
|
-
assert (loss := var.get_loss(backward=False)) == 4.0, loss
|
|
34
|
-
assert evaluated, evaluated
|
|
35
|
-
assert loss is var.loss
|
|
36
|
-
assert var.loss == 4.0
|
|
37
|
-
assert var.loss_approx == 4.0
|
|
38
|
-
assert var.grad is None, var.grad
|
|
39
|
-
|
|
40
|
-
# reevaluate, which should just return already evaluated loss
|
|
41
|
-
assert (loss := var.get_loss(backward=False)) == 4.0, loss
|
|
42
|
-
assert var.grad is None, var.grad
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
# ----------------------- test that backward=True works ---------------------- #
|
|
46
|
-
params = [torch.tensor(3.0, requires_grad=True)]
|
|
47
|
-
evaluated = False
|
|
48
|
-
|
|
49
|
-
def closure_2(backward=True):
|
|
50
|
-
# ensure closure only evaluates once
|
|
51
|
-
nonlocal evaluated
|
|
52
|
-
assert evaluated is False, 'closure was evaluated twice'
|
|
53
|
-
evaluated = True
|
|
54
|
-
|
|
55
|
-
loss = params[0] * 2
|
|
56
|
-
if backward:
|
|
57
|
-
assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
|
|
58
|
-
params[0].grad = None
|
|
59
|
-
loss.backward()
|
|
60
|
-
else:
|
|
61
|
-
assert not loss.requires_grad, "loss requires grad with backward=False"
|
|
62
|
-
return loss
|
|
63
|
-
|
|
64
|
-
var = Var(params=params, closure=closure_2, model=None, current_step=0)
|
|
65
|
-
assert var.grad is None, var.grad
|
|
66
|
-
assert (loss := var.get_loss(backward=True)) == 6.0, loss
|
|
67
|
-
assert var.grad is not None
|
|
68
|
-
assert var.grad[0] == 2.0, var.grad
|
|
69
|
-
|
|
70
|
-
# reevaluate, which should just return already evaluated loss
|
|
71
|
-
assert (loss := var.get_loss(backward=True)) == 6.0, loss
|
|
72
|
-
assert var.grad[0] == 2.0, var.grad
|
|
73
|
-
|
|
74
|
-
# get grad, which should just return already evaluated grad
|
|
75
|
-
assert (grad := var.get_grad())[0] == 2.0, grad
|
|
76
|
-
assert grad is var.grad, grad
|
|
77
|
-
|
|
78
|
-
# get update, which should create and return cloned grad
|
|
79
|
-
assert var.update is None
|
|
80
|
-
assert (update := var.get_update())[0] == 2.0, update
|
|
81
|
-
assert update is var.update
|
|
82
|
-
assert update is not var.grad
|
|
83
|
-
assert var.grad is not None
|
|
84
|
-
assert update[0] == var.grad[0]
|
|
85
|
-
|
|
86
|
-
@torch.no_grad
|
|
87
|
-
def test_var_get_grad():
|
|
88
|
-
params = [torch.tensor(2.0, requires_grad=True)]
|
|
89
|
-
evaluated = False
|
|
90
|
-
|
|
91
|
-
def closure(backward=True):
|
|
92
|
-
# ensure closure only evaluates once
|
|
93
|
-
nonlocal evaluated
|
|
94
|
-
assert evaluated is False, 'closure was evaluated twice'
|
|
95
|
-
evaluated = True
|
|
96
|
-
|
|
97
|
-
loss = params[0]**2
|
|
98
|
-
if backward:
|
|
99
|
-
assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
|
|
100
|
-
params[0].grad = None
|
|
101
|
-
loss.backward()
|
|
102
|
-
else:
|
|
103
|
-
assert not loss.requires_grad, "loss requires grad with backward=False"
|
|
104
|
-
return loss
|
|
105
|
-
|
|
106
|
-
var = Var(params=params, closure=closure, model=None, current_step=0)
|
|
107
|
-
assert (grad := var.get_grad())[0] == 4.0, grad
|
|
108
|
-
assert grad is var.grad
|
|
109
|
-
|
|
110
|
-
assert var.loss == 4.0
|
|
111
|
-
assert (loss := var.get_loss(backward=False)) == 4.0, loss
|
|
112
|
-
assert (loss := var.get_loss(backward=True)) == 4.0, loss
|
|
113
|
-
assert var.loss_approx == 4.0
|
|
114
|
-
|
|
115
|
-
assert var.update is None, var.update
|
|
116
|
-
assert (update := var.get_update())[0] == 4.0, update
|
|
117
|
-
|
|
118
|
-
@torch.no_grad
|
|
119
|
-
def test_var_get_update():
|
|
120
|
-
params = [torch.tensor(2.0, requires_grad=True)]
|
|
121
|
-
evaluated = False
|
|
122
|
-
|
|
123
|
-
def closure(backward=True):
|
|
124
|
-
# ensure closure only evaluates once
|
|
125
|
-
nonlocal evaluated
|
|
126
|
-
assert evaluated is False, 'closure was evaluated twice'
|
|
127
|
-
evaluated = True
|
|
128
|
-
|
|
129
|
-
loss = params[0]**2
|
|
130
|
-
if backward:
|
|
131
|
-
assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
|
|
132
|
-
params[0].grad = None
|
|
133
|
-
loss.backward()
|
|
134
|
-
else:
|
|
135
|
-
assert not loss.requires_grad, "loss requires grad with backward=False"
|
|
136
|
-
return loss
|
|
137
|
-
|
|
138
|
-
var = Var(params=params, closure=closure, model=None, current_step=0)
|
|
139
|
-
assert var.update is None, var.update
|
|
140
|
-
assert (update := var.get_update())[0] == 4.0, update
|
|
141
|
-
assert update is var.update
|
|
142
|
-
|
|
143
|
-
assert (grad := var.get_grad())[0] == 4.0, grad
|
|
144
|
-
assert grad is var.grad
|
|
145
|
-
assert grad is not update
|
|
146
|
-
|
|
147
|
-
assert var.loss == 4.0
|
|
148
|
-
assert (loss := var.get_loss(backward=False)) == 4.0, loss
|
|
149
|
-
assert (loss := var.get_loss(backward=True)) == 4.0, loss
|
|
150
|
-
assert var.loss_approx == 4.0
|
|
151
|
-
|
|
152
|
-
assert (update := var.get_update())[0] == 4.0, update
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
|
|
156
|
-
for k,v in v1.__dict__.items():
|
|
157
|
-
if not k.startswith('__'):
|
|
158
|
-
# if k == 'post_step_hooks': continue
|
|
159
|
-
if k == 'storage': continue
|
|
160
|
-
if k == 'update' and clone_update:
|
|
161
|
-
if v1.update is None or v2.update is None:
|
|
162
|
-
assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
|
|
163
|
-
else:
|
|
164
|
-
assert (TensorList(v1.update) == TensorList(v2.update)).global_all()
|
|
165
|
-
assert v1.update is not v2.update
|
|
166
|
-
else:
|
|
167
|
-
assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
|
|
168
|
-
|
|
169
|
-
def test_var_clone():
|
|
170
|
-
model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
|
|
171
|
-
def closure(backward): return 1
|
|
172
|
-
var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
|
|
173
|
-
|
|
174
|
-
_assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
|
|
175
|
-
_assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
|
|
176
|
-
|
|
177
|
-
var.grad = TensorList(torch.randn(5))
|
|
178
|
-
_assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
|
|
179
|
-
_assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
|
|
180
|
-
|
|
181
|
-
var.update = TensorList(torch.randn(5) * 2)
|
|
182
|
-
var.loss = torch.randn(1)
|
|
183
|
-
var.loss_approx = var.loss
|
|
184
|
-
_assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
|
|
185
|
-
_assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
|