torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +1 -1
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +13 -9
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/optuna_tools.py +1 -1
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0

{torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD
@@ -1,54 +1,58 @@
-tests/test_identical.py,sha256=
+tests/test_identical.py,sha256=8Pw52Q19yeK5maYQEd2HYoOMItN599oRMUKzl-EugfQ,11550
 tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
 tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
 tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
 tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
 tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
 tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
-torchzero/__init__.py,sha256=
-torchzero/
-torchzero/
-torchzero/
+torchzero/__init__.py,sha256=SZLJgf_sjHyqtTzz0f70AtHP_V_WloX1KQF8mm34zdg,175
+torchzero/_minimize/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchzero/_minimize/methods.py,sha256=1oktoSdWtiA0JEF34yTkY3_nPB5Q5ODHl18C0mcglNw,2445
+torchzero/_minimize/minimize.py,sha256=JJBmREQvhDxyqGM62xharsuebyefxRADkd6Bg_TE-DQ,17236
+torchzero/core/__init__.py,sha256=lufcll5r98gTjVfQSvz6-wfI0qMAgZtLLSByHuHTats,358
+torchzero/core/chain.py,sha256=-6vW-L5pzg2Rwpq3LKIAoqJGPvCkHKjt_B1boGikQmM,1900
+torchzero/core/functional.py,sha256=D125Hso8fHMSKlyhkir3GGJzXxuIitXmVhKn2Y9x-Ck,3272
 torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
-torchzero/core/module.py,sha256=
+torchzero/core/module.py,sha256=DKGLwLWm9LkOBYZHW9QBoXo9eBgnYz7nmoCXJ0gl0e0,21210
 torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
 torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
-torchzero/core/transform.py,sha256=
+torchzero/core/transform.py,sha256=WlHoc5cCY1vXQrwMsIG0g3Kle93kBSbrBfxGz5X9_0Q,12251
 torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
 torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
-torchzero/linalg/eigh.py,sha256=
+torchzero/linalg/eigh.py,sha256=l1fX_7hL-DFk8gu20-NuSKDJcRpz58KxUKQHeBhCcHE,9035
 torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
-torchzero/linalg/linear_operator.py,sha256=
+torchzero/linalg/linear_operator.py,sha256=MWTY7DS8B8IkR28kVA9nmoM-OU-1eBsP22iYXkDrj9A,16654
 torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
-torchzero/linalg/orthogonalize.py,sha256=
+torchzero/linalg/orthogonalize.py,sha256=GSvDZA9evTpu3obqCkEocgpDp_91sRexoAwH2q0zTEY,5345
 torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
+torchzero/linalg/sketch.py,sha256=dKD9t7I7stv089cCvZyPAOZ0D9wzVG1TmV3297w0tk4,1261
 torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
 torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
 torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
-torchzero/modules/__init__.py,sha256=
+torchzero/modules/__init__.py,sha256=ZN20E2ES6zDf5DuFbZpuCKFinFc5eGR1h00iYZ_XBGU,652
 torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
 torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
-torchzero/modules/adaptive/adagrad.py,sha256=
+torchzero/modules/adaptive/adagrad.py,sha256=NDwmUZaEk0lWnbgYxN23yTWK5A5dQ9BtoKzRTFSKozY,12131
 torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
-torchzero/modules/adaptive/adam.py,sha256=
-torchzero/modules/adaptive/adan.py,sha256=
+torchzero/modules/adaptive/adam.py,sha256=RDHYyIAJdi1Pxny8HOHiCFgvPztNwlJlCtzE_ZE-138,3896
+torchzero/modules/adaptive/adan.py,sha256=tmQHiJ5MNwOGP3fp479goHh0xXlhnzULhHxKcVZOkvM,4219
 torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
 torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
 torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
-torchzero/modules/adaptive/ggt.py,sha256=
-torchzero/modules/adaptive/lion.py,sha256=
+torchzero/modules/adaptive/ggt.py,sha256=7G0Hh8lWy4o73VYVHcZ1JJyDqeKcXi2Y6Qp3qIHosOY,6512
+torchzero/modules/adaptive/lion.py,sha256=yeKUt3WIITtWx97IQzudgbdai77MCfnL_cu90vRkTmA,1141
 torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
-torchzero/modules/adaptive/mars.py,sha256=
+torchzero/modules/adaptive/mars.py,sha256=WquKzTnCZcxzslcvSBMFJVz_kjuCuAzlesw1bHnKqOg,2325
 torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
-torchzero/modules/adaptive/msam.py,sha256=
+torchzero/modules/adaptive/msam.py,sha256=cHfdNkk3Joy2aENwUZXGf3N0P7zcxYGKuySf699OTfM,7051
 torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
 torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
 torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
-torchzero/modules/adaptive/rmsprop.py,sha256=
-torchzero/modules/adaptive/rprop.py,sha256=
+torchzero/modules/adaptive/rmsprop.py,sha256=sb709Smpkm8H3vYOsh7BzWni5hAf3nBQevhagyOt7mo,4655
+torchzero/modules/adaptive/rprop.py,sha256=vw-Rufa-gpHgq1gDarmNQexrFr13lPLq_mj3c3pNB_Q,11593
 torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
-torchzero/modules/adaptive/shampoo.py,sha256=
-torchzero/modules/adaptive/soap.py,sha256=
+torchzero/modules/adaptive/shampoo.py,sha256=1WpjroFS37HmDLV51iK4d8vtnJWFrGCsDkoQav0p47E,10048
+torchzero/modules/adaptive/soap.py,sha256=jyS6F2o4bMKzMU8H2dDggFQEqMqw4W1rX78u8p3uaV4,12619
 torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
 torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
 torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
@@ -58,21 +62,20 @@ torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2G
 torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
 torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
 torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
+torchzero/modules/basis/__init__.py,sha256=MeXoykwqqmWt-Gx8YWMycVL7m5N4j7Ob_L0GbcwLOfM,65
+torchzero/modules/basis/ggt_basis.py,sha256=NVddZrv58lm7M2Q2j5_3YYLcBYRdeSB_y03bxExSiJs,7772
+torchzero/modules/basis/soap_basis.py,sha256=pwlxIa9lW9V1NcLPmhm--LVbyq7ALSfkV_4b6ki1hO8,10479
 torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
 torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
-torchzero/modules/clipping/ema_clipping.py,sha256=
-torchzero/modules/clipping/growth_clipping.py,sha256=
+torchzero/modules/clipping/ema_clipping.py,sha256=7lFkQWVkchxlZynYXS4JDjhxB8T5tbE0qsP3GXK6mrA,6916
+torchzero/modules/clipping/growth_clipping.py,sha256=VAmUUeIsSGWrGmZiFAngWUBBsxj4d0QAMf36oAMZL8A,6556
 torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
 torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
-torchzero/modules/experimental/__init__.py,sha256=
-torchzero/modules/experimental/
-torchzero/modules/experimental/
-torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
-torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
+torchzero/modules/experimental/__init__.py,sha256=najUDh01Av6gEeMYRV9X9lWAr4ZrC6ZgJcPtNpon7ZQ,734
+torchzero/modules/experimental/coordinate_momentum.py,sha256=4BMmgooPysYlX7QOaTUjBn6MNfBAMujM5TCm72vSexw,1152
+torchzero/modules/experimental/cubic_adam.py,sha256=97sgbtkqG1ziXOMxlCor-L-UzzqgSumz8shVOgYL4oQ,5303
 torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
 torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
-torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
-torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
 torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
 torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
 torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
@@ -82,15 +85,15 @@ torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNg
 torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
 torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
 torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
-torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
 torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
-torchzero/modules/grad_approximation/__init__.py,sha256=
+torchzero/modules/grad_approximation/__init__.py,sha256=BAFXc73_ORySVDyXiyZxpusXWn7K66KFT9LZEMwVKes,221
 torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
 torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
 torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
 torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
+torchzero/modules/grad_approximation/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
 torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
-torchzero/modules/least_squares/gn.py,sha256=
+torchzero/modules/least_squares/gn.py,sha256=hufsWNq_UdEPFDFKNGgCiM4R9739Xu8JqYWSwKkdSZ8,8087
 torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
 torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
 torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
@@ -102,21 +105,21 @@ torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4J
 torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
 torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
 torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
-torchzero/modules/misc/gradient_accumulation.py,sha256=
+torchzero/modules/misc/gradient_accumulation.py,sha256=1BVqGXwv1YPg7DRJWP0XY6s-vzxrvyXLdruM1Y5KJ5s,2326
 torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
-torchzero/modules/misc/misc.py,sha256=
+torchzero/modules/misc/misc.py,sha256=eWVyYSYiQxcS7G7aVM4nFYiF0csE9qcztTaP4id5CbE,15306
 torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
 torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
 torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
 torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
 torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
-torchzero/modules/momentum/averaging.py,sha256=
+torchzero/modules/momentum/averaging.py,sha256=OTO_LRNiAhbcKTXrWI-uENqIOH_3DX5_1uYJ3eMVcJY,3202
 torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
-torchzero/modules/momentum/momentum.py,sha256=
-torchzero/modules/ops/__init__.py,sha256=
-torchzero/modules/ops/accumulate.py,sha256=
+torchzero/modules/momentum/momentum.py,sha256=aJ8o3gB9HebM9kutpadC5wI0MgMjn-c3J4GF3Z_n0Oc,4484
+torchzero/modules/ops/__init__.py,sha256=p5hwECuODOv6E4H0lETQHweSsUtMlsGE0d8bfTv2Rwc,1225
+torchzero/modules/ops/accumulate.py,sha256=mbJFwykU2fa6IIfsHVXdhmRp7QX1czpCWjw6AYkNn1k,3636
 torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
-torchzero/modules/ops/higher_level.py,sha256=
+torchzero/modules/ops/higher_level.py,sha256=f9DFNI9rnxc-rShAJOfsiwvyGsWu8FsJwJf5yg_V4eg,9366
 torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
 torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
 torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
@@ -136,12 +139,12 @@ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6Upy
 torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
 torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
 torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
-torchzero/modules/second_order/inm.py,sha256=
+torchzero/modules/second_order/inm.py,sha256=_FnaUHKLl46AtI_XYwF52wtOUbAaO5EMUNRJspX5FEM,3574
 torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
-torchzero/modules/second_order/newton.py,sha256=
-torchzero/modules/second_order/newton_cg.py,sha256=
-torchzero/modules/second_order/nystrom.py,sha256=
-torchzero/modules/second_order/rsn.py,sha256=
+torchzero/modules/second_order/newton.py,sha256=W37_ePdAB1wnlRrNRd2ovNgkbodK1JV8J4SJytVuF_M,11456
+torchzero/modules/second_order/newton_cg.py,sha256=gHmpLRQ2FRr0750gYkFQ7XweJVZmYI6yG9H2vrKvAdA,14925
+torchzero/modules/second_order/nystrom.py,sha256=lGLjtzq2WAWcaT3E6Say82ySZ1yp9I2ASuOqyNTUmiQ,13361
+torchzero/modules/second_order/rsn.py,sha256=13t42cUvY8JQMC4zf4UsqKvpnTXuXZUZJDECCxRYWjg,11286
 torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
 torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
 torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
@@ -155,7 +158,7 @@ torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdy
 torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
 torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
 torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
-torchzero/modules/trust_region/trust_region.py,sha256=
+torchzero/modules/trust_region/trust_region.py,sha256=ax1pJDr3NPLfojUXRMb-hsxD4MpQL1bPAOwozAVTCJI,12930
 torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
 torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
 torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
@@ -195,15 +198,15 @@ torchzero/utils/derivatives.py,sha256=Sc20EH2v2czjH9Z8UChvq0EaYtvOEJKEYOk3fVb0Z6
 torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
 torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
 torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
-torchzero/utils/optuna_tools.py,sha256=
-torchzero/utils/params.py,sha256
+torchzero/utils/optuna_tools.py,sha256=t64nwyuIVP7xgeGVvIGMFBij2j5clhjY4BHtGEnyPVI,1323
+torchzero/utils/params.py,sha256=-amJs518rpI0zzYavTlWrl60JNrgsk1xxdGvIrSw1ZI,6406
 torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
-torchzero/utils/tensorlist.py,sha256=
+torchzero/utils/tensorlist.py,sha256=wpzBJvIAmw9VXsg1UF8gZtq-eh7GlvdM6WL_7NyPYlY,56363
 torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
 torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
 torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-torchzero/utils/benchmarks/logistic.py,sha256=
-torchzero-0.4.
-torchzero-0.4.
-torchzero-0.4.
-torchzero-0.4.
+torchzero/utils/benchmarks/logistic.py,sha256=1c9kB6tDaKsSNlQn44_Lso2_g-85fQK45RvwLZOcJOo,4587
+torchzero-0.4.3.dist-info/METADATA,sha256=39RK0MpaBQIm0GpIK2YRwoeY5zegEBnJHCZIY4ExQ5k,564
+torchzero-0.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+torchzero-0.4.3.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
+torchzero-0.4.3.dist-info/RECORD,,

torchzero/modules/experimental/adanystrom.py (removed)
@@ -1,258 +0,0 @@
-# pylint: disable = non-ascii-name
-import torch
-
-from ...core import Chainable, TensorTransform
-from ...linalg import (
-    OrthogonalizeMethod,
-    orthogonalize,
-    regularize_eigh,
-    torch_linalg,
-)
-from ...linalg.linear_operator import Eigendecomposition
-from ..adaptive.lre_optimizers import LREOptimizerBase
-from .eigengrad import _eigengrad_update_state_, eigengrad_apply
-
-
-def weighted_eigen_plus_rank1_mm(
-    # A1 = Q1 @ diag(L1) @ Q1.T
-    L1: torch.Tensor,
-    Q1: torch.Tensor,
-
-    # A2 = v2 @ v2.T
-    v2: torch.Tensor,
-
-    # second matrix
-    B: torch.Tensor,
-
-    # weights
-    w1: float,
-    w2: float,
-
-) -> torch.Tensor:
-    """
-    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
-
-    Returns ``(n, k)``
-
-    Args:
-        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
-        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
-        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
-        B (torch.Tensor): shape ``(n, k)``.
-        w1 (float): weight for A1.
-        w2 (float): weight for A2.
-
-    """
-    # sketch A1
-    QTB = Q1.T @ B  # (rank, k)
-    LQTB = L1.unsqueeze(1) * QTB  # (rank, k)
-    sketch1 = Q1 @ LQTB  # (n, k)
-
-    # sketch A2
-    vB = v2 @ B
-    sketch2 = v2.outer(vB)
-
-    return w1 * sketch1 + w2 * sketch2
-
-
-def adanystrom_update(
-    L1: torch.Tensor,
-    Q1: torch.Tensor,
-    v2: torch.Tensor,
-    w1: float,
-    w2: float,
-    oversampling_p: int,
-    rank: int,
-    eig_tol: float,
-    damping: float,
-    rdamping: float,
-    orthogonalize_method: OrthogonalizeMethod,
-
-) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
-    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
-
-    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
-
-    Args:
-        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
-        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
-        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
-        w1 (float): weight for A1.
-        w2 (float): weight for A2.
-    """
-    n = Q1.shape[0]
-    device = Q1.device
-    dtype = Q1.dtype
-    l = rank + oversampling_p
-
-    # gaussian test matrix
-    Omega = torch.randn(n, l, device=device, dtype=dtype)
-
-    # sketch
-    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
-    Q = orthogonalize(AOmega, orthogonalize_method)
-
-    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
-    QTAQ = Q.T @ AQ
-
-    W = (QTAQ + QTAQ.T) / 2.0
-
-    # compute new L and Q
-    try:
-        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
-    except torch.linalg.LinAlgError:
-        return L1, Q1
-
-    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
-
-    if L_prime is None or S is None:
-        return L1, Q1
-
-    return L_prime, Q @ S
-
-
-# def adanystrom_update2(
-#     L1: torch.Tensor,
-#     Q1: torch.Tensor,
-#     v2: torch.Tensor,
-#     w1: float,
-#     w2: float,
-#     rank: int,
-# ):
-#     def A_mm(X):
-#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
-
-#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)

-class AdaNystrom(TensorTransform):
-    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
-
-    Args:
-        rank (int): rank of Nyström approximation.
-        w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
-        w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
-        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
-        eig_tol (float, optional):
-            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
-        damping (float, optional):
-            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
-        rdamping (float, optional):
-            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
-        mm_tol (float, optional):
-            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
-        mm_truncate (int | None, optional):
-            uses top k eigenvalues to compute the update. Defaults to None.
-        mm_damping (float, optional):
-            added to eigenvalues when computing the update. Defaults to 1e-4.
-        mm_rdamping (float, optional):
-            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
-        id_reg (float, optional):
-            multiplier to identity matrix added to preconditioner before computing the update.
-            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
-            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
-        concat_params (bool, optional):
-            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
-        update_freq (int, optional): update frequency. Defaults to 1.
-        inner (Chainable | None, optional): inner modules. Defaults to None.
-    """
-    def __init__(
-        self,
-        rank: int = 100,
-        beta=0.95,
-        oversampling: int = 10,
-        eig_tol: float | None = 1e-32,
-        damping: float = 0,
-        rdamping: float = 0,
-        mm_tol: float = 0,
-        mm_truncate: int | None = None,
-        mm_damping: float = 0,
-        mm_rdamping: float = 0,
-        id_reg: float | None = None,
-        orthogonalize_method: OrthogonalizeMethod = 'qr',
-        eigenbasis_optimizer: LREOptimizerBase | None = None,
-        orthogonalize_interval: int | None = 100,
-
-        concat_params: bool = True,
-        update_freq: int = 1,
-        inner: Chainable | None = None,
-    ):
-        defaults = locals().copy()
-        for k in ["self", "concat_params", "inner", "update_freq"]:
-            del defaults[k]
-
-        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
-
-    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
-        state["step"] = state.get("step", 0) + 1
-        rank = setting["rank"]
-        device = tensor.device
-        dtype = tensor.dtype
-        beta = setting["beta"]
-
-        try:
-            if "L" not in state:
-                # use just tensor and zero L and Q with zero weight
-
-                L, Q = adanystrom_update(
-                    L1=torch.zeros(rank, device=device, dtype=dtype),
-                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
-                    v2=tensor.ravel(),
-                    w1=0,
-                    w2=1-beta,
-                    rank=rank,
-                    oversampling_p=setting["oversampling"],
-                    eig_tol=setting["eig_tol"],
-                    damping=setting["damping"],
-                    rdamping=setting["rdamping"],
-                    orthogonalize_method=setting["orthogonalize_method"],
-                )
-
-                state["L"] = state["L_reg"] = L
-                state["Q"] = state["Q_reg"] = Q
-
-            else:
-                L = state["L"]
-                Q = state["Q"]
-
-                w1 = beta
-                w2 = 1 - w1
-
-                # compute new factors (this function truncates them)
-                L_new, Q_new = adanystrom_update(
-                    L1=L,
-                    Q1=Q,
-                    v2=tensor.ravel(),
-                    w1=w1,
-                    w2=w2,
-                    rank=rank,
-                    oversampling_p=setting["oversampling"],
-                    eig_tol=setting["eig_tol"],
-                    damping=setting["damping"],
-                    rdamping=setting["rdamping"],
-                    orthogonalize_method=setting["orthogonalize_method"],
-                )
-
-                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
-
-        except torch.linalg.LinAlgError:
-            pass
-
-    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
-        if "L_reg" not in state:
-            return tensor.clip(-0.1, 0.1)
-
-        if "eigenbasis_state" not in state:
-            state["eigenbasis_state"] = {}
-
-        return eigengrad_apply(
-            tensor=tensor,
-            L_reg=state["L_reg"],
-            Q_reg=state["Q_reg"],
-            beta=setting["beta"],
-            step=state["step"],
-            debias=True,
-            id_reg=setting["id_reg"],
-            eigenbasis_optimizer=setting["eigenbasis_optimizer"],
-            eigenbasis_state=state["eigenbasis_state"],
-        )
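
The removed `adanystrom_update` above is a randomized one-pass low-rank eigendecomposition refresh: apply the weighted operator `w1 * A1 + w2 * A2` to a Gaussian test matrix (via `weighted_eigen_plus_rank1_mm`, which never materializes the dense `(n, n)` matrix), orthonormalize the sketch, solve the small projected eigenproblem, and keep the top eigenpairs. A minimal standalone sketch of that procedure, assuming only plain PyTorch; `nystrom_eigh` and its arguments are illustrative names, not torchzero API:

```python
import torch

def nystrom_eigh(A_mm, n: int, rank: int, oversampling: int = 10):
    """One-pass randomized approximation A ~= Q' diag(L') Q'^T from matmul products only."""
    l = rank + oversampling
    Omega = torch.randn(n, l)                  # gaussian test matrix
    Q, _ = torch.linalg.qr(A_mm(Omega))        # orthonormal basis for the sketched range
    W = Q.T @ A_mm(Q)                          # project A into that subspace, (l, l)
    W = (W + W.T) / 2                          # symmetrize against round-off
    L, S = torch.linalg.eigh(W)                # cheap dense eigh of the small matrix
    return L[-rank:], Q @ S[:, -rank:]         # keep the top-`rank` eigenpairs

# sanity check on a PSD matrix of exactly rank 8
n = 128
G = torch.randn(n, 8)
A = G @ G.T
L, Q = nystrom_eigh(lambda X: A @ X, n, rank=8)
print(torch.dist(Q @ torch.diag(L) @ Q.T, A))  # ~0: the sketch captures the full range
```

In `AdaNystrom` the operator being approximated is the exponential moving average of gradient outer products, so `A_mm` is the factored `weighted_eigen_plus_rank1_mm` and the per-step cost stays linear in `n` rather than quadratic.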

torchzero/modules/experimental/common_directions_whiten.py (removed)
@@ -1,142 +0,0 @@
-from collections import deque
-from typing import Literal
-
-import torch
-
-from torchzero.core import Chainable, TensorTransform
-from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
-from torchzero.utils import TensorList, vec_to_tensors_
-
-
-def update_subspace_preconditioner_(
-    grad: torch.Tensor,  # store grads and basis as vectors for matmul
-    basis: torch.Tensor,  # ndim, k
-    accumulator_: torch.Tensor,  # k, k
-    beta: float | None,
-):
-    projected = basis.T @ grad  # k
-    outer = torch.outer(projected, projected)
-
-    if beta is None: accumulator_.add_(outer)
-    else: accumulator_.lerp_(outer, 1-beta)
-
-# subspace optimizers can also be run in this basis
-def apply_subspace_preconditioner(
-    tensor: torch.Tensor,
-    basis: torch.Tensor,  # ndim, k
-    accumulator: torch.Tensor,
-    tol: float,
-    truncate: int | None,
-    damping: float,
-    rdamping: float,
-):
-    L, Q = torch_linalg.eigh(accumulator, retry_float64=True)
-    L, Q = regularize_eigh(L=L, Q=Q, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping)
-
-    if L is None or Q is None:
-        return tensor.clip(-0.1, 0.1)
-
-    preconditioner = (Q * L.rsqrt().unsqueeze(-2)) @ Q.mH
-
-    tensor_projected = basis.T @ tensor  # k
-    update_projected = preconditioner @ tensor_projected  # k
-    return basis @ update_projected  # d
-
-
-class CommonDirectionsWhiten(TensorTransform):
-    """Whitens in the subspace spanned by a history of gradient differences.
-
-    Args:
-        beta: momentum for the preconditioner itself in the basis.
-        basis_beta: controls how much the basis is allowed to change.
-    """
-
-    def __init__(
-        self,
-        k: int = 100,
-        beta: float | None = 0.95,
-        basis_beta=0.95,
-        tol: float = 1e-7,
-        truncate: int | None = None,
-        damping: float = 1e-4,
-        rdamping: float = 0,
-        basis_type: Literal["gradients", "differences"] = "differences",
-        orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',
-
-        concat_params: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = locals().copy()
-        for key in ["self", "inner", "concat_params"]:
-            del defaults[key]
-
-        super().__init__(defaults, concat_params=concat_params, inner=inner)
-
-    @torch.no_grad
-    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
-        g = tensor.ravel()
-        k = setting['k']
-        beta = setting['beta']
-        basis_beta = setting['basis_beta']
-        step = state.get("step", 0)
-        state["step"] = step + 1
-
-        # initialize history
-        if 'history' not in state:
-            state['history'] = deque(maxlen=k)
-            state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-            state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)
-
-        history: deque = state['history']
-        accumulator = state['accumulator']
-        basis = state['basis']
-        history.append(g)
-
-        # stack history to new basis term; if history isn't full, fill with random vecs
-        if len(history) < k:
-            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            history_basis = torch.stack(tuple(history), -1)
-            basis_t[:, -len(history):] = history_basis
-
-        else:
-            basis_t = torch.stack(tuple(history), -1)
-
-        # in this case basis uses differences in gradients except last entry is the gradient
-        if setting["basis_type"] == "differences":
-            basis_t[:, :-1] = basis_t[:, :-1] - basis_t[:, 1:]
-
-        # normalize or orthonormalize new basis term
-        if setting["orthogonalize_method"] is not None:
-            basis_t = orthogonalize(basis_t, method=setting["orthogonalize_method"])
-        else:
-            basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)
-
-        # lerp basis
-        basis.lerp_(basis_t, 1-basis_beta)
-        basis = basis / (1 - basis_beta ** (step+1))  # correct bias on basis EMA
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-    @torch.no_grad
-    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
-        g = tensor.ravel()
-
-        basis = state['basis']
-        accumulator = state['accumulator']
-        step = state["step"]
-        accumulator = accumulator / (1 - setting["beta"] ** (step+1))  # correct bias on accumulator EMA
-
-        try:
-            preconditioned = apply_subspace_preconditioner(
-                g,
-                basis,
-                accumulator,
-                tol=setting["tol"],
-                truncate=setting["truncate"],
-                damping=setting["damping"],
-                rdamping=setting["rdamping"],
-            )
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1, 0.1)
-
-        return preconditioned.view_as(tensor)
-
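
The core step of the removed `CommonDirectionsWhiten` is ordinary whitening restricted to a k-dimensional subspace: project the gradient onto the basis, multiply by the inverse square root of the accumulated covariance of projected gradients (via its eigendecomposition), and map back. A minimal sketch of that step under the same damping idea, written against plain PyTorch rather than the torchzero module API; `whiten_in_subspace` and the stand-in inputs are illustrative:

```python
import torch

def whiten_in_subspace(g, basis, accumulator, damping=1e-4):
    # basis: (ndim, k) with roughly orthonormal columns; accumulator: (k, k) SPD covariance
    L, Q = torch.linalg.eigh(accumulator)      # accumulator = Q diag(L) Q^T
    L = L.clamp(min=0.0) + damping             # guard tiny or negative eigenvalues
    inv_sqrt = (Q * L.rsqrt()) @ Q.T           # accumulator^{-1/2}
    g_proj = basis.T @ g                       # (k,) coordinates in the subspace
    return basis @ (inv_sqrt @ g_proj)         # whitened update mapped back to (ndim,)

ndim, k = 1000, 16
basis, _ = torch.linalg.qr(torch.randn(ndim, k))   # stand-in for the gradient-history basis
P = torch.randn(k, 4 * k)
accumulator = P @ P.T / (4 * k)                    # stand-in for accumulated outer products
update = whiten_in_subspace(torch.randn(ndim), basis, accumulator)
```

Components of the gradient outside the span of `basis` are dropped by the projection, which is why the module keeps refreshing the basis from recent gradients (or their differences).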