torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,209 @@
+tests/test_identical.py,sha256=Y48_1f5WrltmO8a_-x-9Yltz2ZeMh8N8q3MGjOCkJhA,11552
+tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
+tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
+tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
+tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
+tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
+tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
+torchzero/__init__.py,sha256=nit4KxrRoW6hJDGOy0jkphuawY5gAvPqrYY11Yct6fA,133
+torchzero/core/__init__.py,sha256=h9Ck7XX2XuJUTojU2IMa_2TprXZHbgo748txa3z7-2o,341
+torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
+torchzero/core/functional.py,sha256=TSygtyQHDhqf998--hF48yIFr-y3Ycz8arjjR8x1ILU,3156
+torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
+torchzero/core/module.py,sha256=HfbPfxXxgyBf9wQl7Fpw6B6Ux6UYfvPEmITC64ozb_Q,18012
+torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
+torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
+torchzero/core/transform.py,sha256=aJRBtvYjKqD-Ic_AkzeSINYDsTaBAErA-kocEl3PHZw,12244
+torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
+torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
+torchzero/linalg/eigh.py,sha256=YC8x5NEWWsnc3suCebnTfeb4lVMhy-H8LGOZbGnwd8A,7902
+torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
+torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
+torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
+torchzero/linalg/orthogonalize.py,sha256=Fv6zv1JvS9AVwjiMVed55J8-pEbVZv7vqoEo5g0Zrv0,3270
+torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
+torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
+torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
+torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
+torchzero/modules/__init__.py,sha256=dsOalCw-OVkD8rhpQdcODc3Hsd_sQ2_2xVC-J8mlSuk,632
+torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
+torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
+torchzero/modules/adaptive/adagrad.py,sha256=hMT-Al-vtD6tzPUpQ79LCNko97D7rJN5ji9JOfBqR3k,12015
+torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
+torchzero/modules/adaptive/adam.py,sha256=Okm7Sc9fMArQAZ7Ph4Etq68uL-IXKY4YNqHWpTzPoTY,3767
+torchzero/modules/adaptive/adan.py,sha256=965tBUwKy6uDiY2la6fVcGcsvGMs90Zg-ZHPtozJGe4,4110
+torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
+torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
+torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
+torchzero/modules/adaptive/ggt.py,sha256=eYCeV3GArdLv9WuWeim0V3CHJYl3FVKtrtsGshkqwWg,6608
+torchzero/modules/adaptive/lion.py,sha256=H3aI2qnrMtmkvXcoddzjjxdkoD5cq_QwIkLmd_bVPso,1085
+torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
+torchzero/modules/adaptive/mars.py,sha256=w-cK-1tFuR74SY01xS5jsg1b9qs3l8eOptGrUyQ2m80,2261
+torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
+torchzero/modules/adaptive/msam.py,sha256=nqwjuhBMX2UO-omUIeOcD5ti6PIKfKs-RVCn7ourkKA,6946
+torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
+torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
+torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
+torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
+torchzero/modules/adaptive/rprop.py,sha256=a4_UkWse5u2JFAEIlxQqDBUwvUfxh1kNs2ZIhtccnWE,11540
+torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
+torchzero/modules/adaptive/shampoo.py,sha256=C_Mo7UFQtDxW4McWJjT731FNAp3g9MqF0Hka54Yi3xQ,9847
+torchzero/modules/adaptive/soap.py,sha256=hz2N6-jUSWU93RNViIS1c-Ue2uKmQx6BxyYg6mEa2fo,12408
+torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
+torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
+torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
+torchzero/modules/adaptive/psgd/psgd.py,sha256=CVDJI3fcdPCpLCWFY_pvd_COaZMB3nrYOYDdqQcEaOY,73340
+torchzero/modules/adaptive/psgd/psgd_dense_newton.py,sha256=62R9rkWRynrvmTzb12zi5oQoWWdAWz4i1EC4FXa4zv4,7094
+torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2GG0nfjx6NWzUHYKY_FAvRpg,8381
+torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
+torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
+torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
+torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
+torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
+torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
+torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
+torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
+torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
+torchzero/modules/experimental/__init__.py,sha256=YbBrWu2vkXHiBcDXmim-Yte4ZxfmQCs_0fCeIArvtnM,942
+torchzero/modules/experimental/adanystrom.py,sha256=fUWPxxi1aJhWme_d31dBG0XxEZY1hJr6AEiFHdFDxCQ,8970
+torchzero/modules/experimental/common_directions_whiten.py,sha256=R_1fQKlvMD99oFrflJLgxl6ObV8jyPc7-NxAUFQeoYA,4941
+torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
+torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
+torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
+torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
+torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
+torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
+torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
+torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
+torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
+torchzero/modules/experimental/l_infinity.py,sha256=zu3aRLmZkU4LxfyToqjU8re0BLMUd6gl16_9AffIfcw,4752
+torchzero/modules/experimental/matrix_nag.py,sha256=fjj07uZbmYvy4AgrtqvX_WLwvK7HBpKx09Q3BK3L0jo,4451
+torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNgEnxOr04PVBmIiBxmiQ,4073
+torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
+torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
+torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
+torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
+torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
+torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
+torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
+torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
+torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
+torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
+torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
+torchzero/modules/least_squares/gn.py,sha256=3RQ_7e35Ql9uVUUPi34nef9eQNeZ09fldi964V61Tgg,7889
+torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
+torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
+torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
+torchzero/modules/line_search/backtracking.py,sha256=xomhfRmOquGn_liGLS-5PKP89YpAtzFbM3vPTjz0T5A,9051
+torchzero/modules/line_search/interpolation.py,sha256=tHXlZD1MgfYaymhXY75k9CocltwSYDYZg6ENDCEUiss,4942
+torchzero/modules/line_search/line_search.py,sha256=Qcil-WbXnOpdi_DqzvRzXnInedUYNXDglsnYqfGYHdk,12925
+torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
+torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4JMx2aVHqFn_8qpX6g,14979
+torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
+torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
+torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
+torchzero/modules/misc/gradient_accumulation.py,sha256=Xzjt_ulm6Z3mpmtagoUqoefhoeSDVnmX__tVbcI_RQE,2271
+torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
+torchzero/modules/misc/misc.py,sha256=f-3qxBq1KYI3iGYJXzv1cHEJHc0ScEp-vCLCgiaEgJQ,15002
+torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
+torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
+torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
+torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
+torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
+torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
+torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
+torchzero/modules/momentum/momentum.py,sha256=MPHd4TU1bSlEKLGfueNdmaZ13V5J1suW6agBc3SvrTs,4389
+torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
+torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
+torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
+torchzero/modules/ops/higher_level.py,sha256=cUh-908S0GWVGekmUN5c_Vx0HP3P2tQoKN3COQM5TaQ,8965
+torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
+torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
+torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
+torchzero/modules/ops/utility.py,sha256=UkR3BfN_NBsZe78oERnknf6lmgeGNiEbtfUDNgU0YoQ,4423
+torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
+torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
+torchzero/modules/projections/galore.py,sha256=70k30-2RJ3ncTNZpjxBhyYq1yJVFGfS2YUvU2trIK4o,258
+torchzero/modules/projections/projection.py,sha256=FK1gURrOdSaKLVsTHgSeokoTfCGTLYGY9WB4slpclr4,14485
+torchzero/modules/quasi_newton/__init__.py,sha256=o0N_XvzY5qGOsRYXU2bNdF8_zRb5ZZxaL6inTBl2Sk4,536
+torchzero/modules/quasi_newton/damping.py,sha256=bz2GFjasb-2B6bo8dn3-UJ3ddkiUbiJtaQBLqsixv_E,2806
+torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=q6qCYQzWZxqg5S55mmI6Kf0uRoIr_dtSWrFd-2jpe6o,6929
+torchzero/modules/quasi_newton/lbfgs.py,sha256=ftuY65N7_zJtByQ_JM_sRgxKqW7vNktuMb7qFqDGF8M,11203
+torchzero/modules/quasi_newton/lsr1.py,sha256=5gZKoFBVCLg3XeXOQTALAkk3XvU2wMdjDi5J_B7iJLs,8515
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=vjgYWInavb81twwQQWPfK7gj1aT4ljJNEMSnNsDOM9c,46015
+torchzero/modules/quasi_newton/sg2.py,sha256=-d6dVqlt6Xf03tN1a26Jx9YK2kqOS_snuMuIUT3iYqQ,4164
+torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
+torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
+torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
+torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
+torchzero/modules/second_order/inm.py,sha256=OddoZHQfSuFnlx_7Zj2qiVcC2A_9yMVn_0Gy1A7hNAg,3420
+torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
+torchzero/modules/second_order/newton.py,sha256=QcLXsglvf4zJEwR4cldsGVZCABQtxb6U5qVmU3spN_A,11061
+torchzero/modules/second_order/newton_cg.py,sha256=k8G8CSmeIQZObkWVURFnbF_4g2UvJiwh3xToxn7sFJE,14816
+torchzero/modules/second_order/nystrom.py,sha256=WQFfJj0DOfWXyyx36C54m0WqZPIvTTK7n8U7khLhGLg,13359
+torchzero/modules/second_order/rsn.py,sha256=9s-JyJNNeDlIFv8YVGn7y8DGPnP93WJEjpUQXehX3uY,9980
+torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
+torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
+torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
+torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
+torchzero/modules/step_size/adaptive.py,sha256=TYLESwhcs6QCR9AvQWtMuU-XJRDmNEmSCnPdBFwijiY,14679
+torchzero/modules/step_size/lr.py,sha256=Pi8B5hhJWq_NdVMdR4reZhqu1jjC2S5TWTzkMKAEMcM,5931
+torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
+torchzero/modules/termination/termination.py,sha256=lJXLmtA84JoK_QHhjBxfW9lkxGIxqg7cE7d755MeduE,6905
+torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
+torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdyMvUzdVqOMPyKnJ7FC_d5cR5V8,6711
+torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
+torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
+torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
+torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
+torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
+torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
+torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
+torchzero/modules/weight_decay/reinit.py,sha256=Ngb3AUfJU-aE-lUSSDmf6rlx2zJUgJ6JrQ3U1DpuEh8,3311
+torchzero/modules/weight_decay/weight_decay.py,sha256=dC0JidMlUU4WzXRTSwQT0Gp45UdKvZLu6JnXveJTR4c,5362
+torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
+torchzero/modules/wrappers/optim_wrapper.py,sha256=T-DnI_pbNpt1BN9qvW6KOVGly1RixsdSpIC8FtmmDlY,4701
+torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
+torchzero/modules/zeroth_order/cd.py,sha256=6mHQC_pInVQ6HmhW5yAKwkPjBVwhUg8uX-J3GEIb9wQ,5018
+torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
+torchzero/optim/mbs.py,sha256=mUbl9yl7cxNzmp7J6CXkkjcUd3KBF5zPdDjuiDgJxVU,10970
+torchzero/optim/root.py,sha256=MnXytsgTjDKhYw3UKux1O-g4vuHsFQu2Emsq5zphu_8,2308
+torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
+torchzero/optim/utility/split.py,sha256=uNnKA40OiXUlr-vlHuU_rLEUfXQXgvr6Cd9yGzjWJiA,1702
+torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/optim/wrappers/directsearch.py,sha256=7f1Zy9ZxftgIFxJ9_1snyMjp8ypB4ULyddOqgxdwkEg,9806
+torchzero/optim/wrappers/fcmaes.py,sha256=Bobx87TWinICd3SyNaKhoplpNbrLdD5vR1J1TZw69Cs,3437
+torchzero/optim/wrappers/mads.py,sha256=Gm1msCsVNJDq-WS913peZ0TzXU_7wgP_sPmMfEK0VIY,2326
+torchzero/optim/wrappers/moors.py,sha256=AjEeovPEvxEGJwbhOaOCrCIxPgwOHilkj-vWb-88Y_0,2073
+torchzero/optim/wrappers/nevergrad.py,sha256=wFjIPOSffdRkxdtbwBMBw1_ubP9kS2z4nHV6YtqUrvw,4442
+torchzero/optim/wrappers/nlopt.py,sha256=FzU_3VuNN0_ngT_FlQs7ZOL_RZcfeFtJNmklqvyzTkw,8694
+torchzero/optim/wrappers/optuna.py,sha256=Gvs8tGGIV-mJZ1nMrLN3Ju3am-lbabLd_-mVRO6JlkY,2252
+torchzero/optim/wrappers/pybobyqa.py,sha256=1eH3HW2HVyNneZ5vB9ch1hb0OdoYrnk4pJmQVRIOAw4,6309
+torchzero/optim/wrappers/wrapper.py,sha256=eEZnqKRc2cSFVVXNyOGGReMX3MvY0eoHSRjFsQ8pCSw,4380
+torchzero/optim/wrappers/scipy/__init__.py,sha256=lAciW9WTZ8JZ9ZFunyi78cOa6URhTL-rIjk1AdqbrxM,262
+torchzero/optim/wrappers/scipy/basin_hopping.py,sha256=11FuVLYDJxbyxW-44UO-JRef7MErGFfM_PqqqADA8RQ,4435
+torchzero/optim/wrappers/scipy/brute.py,sha256=MD5fWcUo0-COrq5_OKM8X5SSMZaN--4ThW8mdISLX00,1213
+torchzero/optim/wrappers/scipy/differential_evolution.py,sha256=GBJSh0K5ueZX3fY9vxTntQnuIoLDjOQyOj6TMCmTGBg,2669
+torchzero/optim/wrappers/scipy/direct.py,sha256=lTkH8aRY7sWkhrtQ8Z-m8fSdhTpsLxOIQSJprSqThFA,1821
+torchzero/optim/wrappers/scipy/dual_annealing.py,sha256=mIdXy8F3YEfrQCXNxps8T33yRDuqJxaaIjzB0wlKo90,4117
+torchzero/optim/wrappers/scipy/experimental.py,sha256=-_82LtrDVg0qRg6E1coXugsCGGcVNXgxv6m5Anaaw24,4833
+torchzero/optim/wrappers/scipy/minimize.py,sha256=6O2Js2ecGZkNE3mqXJR9Rqp5xlh3KIVYBMF82k5hrws,6491
+torchzero/optim/wrappers/scipy/sgho.py,sha256=mKG9uG-zaEzB6mbiZ25FnXdRcfvGYYOHapuHniyJ8r8,4054
+torchzero/utils/__init__.py,sha256=tpy9ti5Ub5d1zQPFODZ4PjmFKNzoZTd-NByd_snlYtk,761
+torchzero/utils/compile.py,sha256=dY9ioWQvJt71HQN_z4kXX4KgtJ0xOW0MKNlZpHD6118,5130
+torchzero/utils/derivatives.py,sha256=Sc20EH2v2czjH9Z8UChvq0EaYtvOEJKEYOk3fVb0Z6M,18085
+torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
+torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
+torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
+torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
+torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
+torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
+torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
+torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
+torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
+torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/utils/benchmarks/logistic.py,sha256=RHsjHEWkPqaag0kt3wfmdddh4DhftcyW9r70tj9OGp4,4382
+torchzero-0.4.1.dist-info/METADATA,sha256=hB0rFqXnaRbwVkFRwTwjXpKnIFLi8MBvLXbgXTuUGWk,564
+torchzero-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+torchzero-0.4.1.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
+torchzero-0.4.1.dist-info/RECORD,,
@@ -1,241 +0,0 @@
-from collections import deque
-from typing import Literal, Any
-import warnings
-
-import torch
-from ...core import Chainable, TensorTransform
-from ...linalg import torch_linalg
-
-def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, tol):
-    """returns U ``(ndim, rank)``, L ``(rank, )``"""
-    if isinstance(history, torch.Tensor):
-        M = history
-    else:
-        M = torch.stack(tuple(history), dim=1)  # / len(history)
-
-    MTM = M.T @ M
-    if damping != 0:
-        MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
-
-    try:
-        L, Q = torch_linalg.eigh(MTM, retry_float64=True)
-
-        # truncate to the top n largest eigenvalues
-        if truncate is not None and truncate > 0:
-            # L is ordered in ascending order
-            L = L[-truncate:]
-            Q = Q[:, -truncate:]
-
-        # remove eigenvalues that are small relative to the largest one
-        L_max = L.amax()
-        indices = L > tol * L_max
-        if indices.any():
-            L = L[indices]
-            Q = Q[:, indices]
-
-        U = (M @ Q) * L.rsqrt()
-
-        if rdamping != 0:
-            L.add_(rdamping * L_max)
-
-        return U, L
-
-    except torch.linalg.LinAlgError:
-        return None, None
-
-def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor, exp_avg_proj: torch.Tensor | None, beta: float):
-    z = U.T @ g
-
-    if beta != 0:
-        if exp_avg_proj is None: exp_avg_proj = torch.zeros_like(z)
-        exp_avg_proj.lerp_(z, weight=1-beta)
-        z = exp_avg_proj
-
-    return (U * L.rsqrt()) @ z, exp_avg_proj
-
-def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
-    if value is None: return
-    if (key not in state_) or (beta is None): state_[key] = value
-    else:
-        if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
-        else: state_[key].lerp_(value, 1-beta)
-
-class LMAdagrad(TensorTransform):
-    """
-    Limited-memory full-matrix Adagrad.
-
-    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate the update as U S^-1 Uᵀg.
-    It uses an eigendecomposition of MᵀM to get U and S^2 instead, because that is faster when V is not needed.
-
-    This is equivalent to full-matrix Adagrad on recent gradients.
-
-    Args:
-        history_size (int, optional): number of past gradients to store. Defaults to 100.
-        beta (float, optional): beta for momentum maintained in the whitened space. Defaults to 0.0.
-        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
-        damping (float, optional): damping value. Defaults to 1e-4.
-        rdamping (float, optional): damping relative to the norm of the singular values. Defaults to 0.
-        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
-        tol (float, optional): removes eigenvalues this much smaller than the largest eigenvalue. Defaults to 1e-7.
-        order (int, optional):
-            order=2 means gradient differences are used in place of gradients. Higher order uses higher-order differences. Defaults to 1.
-        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
-        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
-        inner (Chainable | None, optional): the preconditioner is applied to the output of this module. Defaults to None.
-
-    ## Examples:
-
-    Limited-memory Adagrad
-
-    ```python
-    optimizer = tz.Modular(
-        model.parameters(),
-        tz.m.LMAdagrad(),
-        tz.m.LR(0.1)
-    )
-    ```
-
-    Adam with the L-Adagrad preconditioner (for debiasing, the second beta is 0.999 arbitrarily)
-
-    ```python
-    optimizer = tz.Modular(
-        model.parameters(),
-        tz.m.LMAdagrad(inner=tz.m.EMA()),
-        tz.m.Debias(0.9, 0.999),
-        tz.m.LR(0.01)
-    )
-    ```
-
-    Stable Adam with the L-Adagrad preconditioner (this is what I would recommend)
-
-    ```python
-    optimizer = tz.Modular(
-        model.parameters(),
-        tz.m.LMAdagrad(inner=tz.m.EMA()),
-        tz.m.Debias(0.9, 0.999),
-        tz.m.ClipNormByEMA(max_ema_growth=1.2),
-        tz.m.LR(0.01)
-    )
-    ```
-
-    Reference:
-        Agarwal N. et al. Efficient full-matrix adaptive regularization // International Conference on Machine Learning. PMLR, 2019, pp. 102-110.
-    """
-
-    def __init__(
-        self,
-        history_size: int = 100,
-        beta: float = 0.0,
-        update_freq: int = 1,
-        damping: float = 1e-4,
-        rdamping: float = 0,
-        truncate: int | None = None,
-        tol: float = 1e-7,
-        order: int = 1,
-        U_beta: float | None = None,
-        L_beta: float | None = None,
-        concat_params: bool = True,
-
-        inner: Chainable | None = None,
-        U_tfm: Chainable | None = None,
-        L_tfm: Chainable | None = None,
-    ):
-        defaults = locals().copy()
-        del defaults['self'], defaults['inner'], defaults['concat_params'], defaults["U_tfm"], defaults["L_tfm"]
-
-        super().__init__(defaults, concat_params=concat_params, inner=inner)
-
-        self.set_child("U", U_tfm)
-        self.set_child("L", L_tfm)
-
-    @torch.no_grad
-    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
-        order = setting['order']
-        history_size = setting['history_size']
-        update_freq = setting['update_freq']
-        U_beta = setting['U_beta']
-        L_beta = setting['L_beta']
-
-        if 'history' not in state: state['history'] = deque(maxlen=history_size)
-        history = state['history']
-
-        if order == 1:
-            t = tensor.clone().view(-1)
-            history.append(t)
-        else:
-            # if order=2, the history holds gradient differences; order=3 holds differences
-            # between differences, etc., scaled by parameter differences
-            cur_p = param.clone()
-            cur_g = tensor.clone()
-            eps = torch.finfo(cur_p.dtype).tiny * 2
-            for i in range(1, order):
-                if f'prev_g_{i}' not in state:
-                    state[f'prev_p_{i}'] = cur_p
-                    state[f'prev_g_{i}'] = cur_g
-                    break
-
-                s = cur_p - state[f'prev_p_{i}']
-                y = cur_g - state[f'prev_g_{i}']
-                state[f'prev_p_{i}'] = cur_p
-                state[f'prev_g_{i}'] = cur_g
-                cur_p = s
-                cur_g = y
-
-                if i == order - 1:
-                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
-                    history.append(cur_g.view(-1))
-
-        step = state.get('step', 0)
-        if step % update_freq == 0 and len(history) != 0:
-
-            # if maintaining momentum, unproject exp_avg before updating the factors, then re-project
-            exp_avg_proj = state.get("exp_avg_proj", None)
-            exp_avg = None
-            if exp_avg_proj is not None and "U" in state:
-                exp_avg = state["U"] @ exp_avg_proj
-
-            # update the factors
-            U, L = lm_adagrad_update(
-                history,
-                damping=setting["damping"],
-                rdamping=setting["rdamping"],
-                truncate=setting["truncate"],
-                tol=setting["tol"],
-            )
-            maybe_lerp_(state, U_beta, 'U', U)
-            maybe_lerp_(state, L_beta, 'L', L)
-
-            # re-project exp_avg with the new factors
-            if U is not None and exp_avg_proj is not None:
-                assert exp_avg is not None
-                state["exp_avg_proj"] = U.T @ exp_avg
-
-        if len(history) != 0:
-            state['step'] = step + 1 # do not increment while there is no history (still gathering s_ks and y_ks)
-
-    @torch.no_grad
-    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
-        U = state.get('U', None)
-        if U is None:
-            # make a conservative step to avoid issues due to different GD scaling
-            return tensor.clip_(-0.1, 0.1)
-
-        # -------------------------------- transforms -------------------------------- #
-        L = state['L']
-        if "L" in self.children:
-            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
-            L = self.inner_step_tensors("L", [L], clone=True)[0]
-
-        if "U" in self.children:
-            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
-            U = self.inner_step_tensors("U", [U], clone=True)[0]
-
-        # ------------------------------- precondition ------------------------------- #
-        g = tensor.view(-1)
-        exp_avg_proj = state.get("exp_avg_proj", None)
-        update, state["exp_avg_proj"] = lm_adagrad_apply(g, U, L, exp_avg_proj, beta=setting["beta"])
-        return update.view_as(tensor)
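The docstring of the removed `LMAdagrad` above relies on one identity: instead of taking an SVD of the tall gradient-history matrix M, it eigendecomposes the small Gram matrix MᵀM and recovers the left singular vectors as U = M Q S⁻¹. Below is a minimal standalone sketch of that identity using only `torch`; none of the names here are torchzero APIs.

```python
import torch

torch.manual_seed(0)
ndim, m = 50, 10
M = torch.randn(ndim, m, dtype=torch.float64)  # columns = recent gradients
g = torch.randn(ndim, dtype=torch.float64)

# Route 1: whitening update U S^-1 U^T g via a thin SVD of M.
U, S, _ = torch.linalg.svd(M, full_matrices=False)
update_svd = (U / S) @ (U.T @ g)

# Route 2: eigendecomposition of the small (m, m) Gram matrix, as in
# lm_adagrad_update. eigh returns L = S^2 (ascending) with eigenvectors Q = V,
# so the left singular vectors are U = M Q S^-1 = (M @ Q) * L.rsqrt().
L, Q = torch.linalg.eigh(M.T @ M)
U2 = (M @ Q) * L.rsqrt()
update_eig = (U2 * L.rsqrt()) @ (U2.T @ g)

print(torch.allclose(update_svd, update_eig))  # True (up to rounding)
```

Route 2 only ever factorizes an m-by-m matrix, where m is the history size, which is why it is cheaper than an SVD of the ndim-by-m matrix M when ndim is large and V is not needed.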