torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. tests/test_identical.py +1 -1
  2. torchzero/__init__.py +3 -1
  3. torchzero/_minimize/__init__.py +0 -0
  4. torchzero/_minimize/methods.py +95 -0
  5. torchzero/_minimize/minimize.py +518 -0
  6. torchzero/core/__init__.py +5 -5
  7. torchzero/core/chain.py +2 -1
  8. torchzero/core/functional.py +2 -1
  9. torchzero/core/module.py +75 -4
  10. torchzero/core/transform.py +6 -5
  11. torchzero/linalg/eigh.py +116 -68
  12. torchzero/linalg/linear_operator.py +1 -0
  13. torchzero/linalg/orthogonalize.py +60 -5
  14. torchzero/linalg/sketch.py +39 -0
  15. torchzero/modules/__init__.py +1 -0
  16. torchzero/modules/adaptive/adagrad.py +2 -0
  17. torchzero/modules/adaptive/adam.py +5 -1
  18. torchzero/modules/adaptive/adan.py +3 -0
  19. torchzero/modules/adaptive/ggt.py +20 -18
  20. torchzero/modules/adaptive/lion.py +3 -1
  21. torchzero/modules/adaptive/mars.py +6 -5
  22. torchzero/modules/adaptive/msam.py +3 -0
  23. torchzero/modules/adaptive/rmsprop.py +2 -0
  24. torchzero/modules/adaptive/rprop.py +9 -7
  25. torchzero/modules/adaptive/shampoo.py +9 -1
  26. torchzero/modules/adaptive/soap.py +32 -29
  27. torchzero/modules/basis/__init__.py +2 -0
  28. torchzero/modules/basis/ggt_basis.py +199 -0
  29. torchzero/modules/basis/soap_basis.py +254 -0
  30. torchzero/modules/clipping/ema_clipping.py +32 -27
  31. torchzero/modules/clipping/growth_clipping.py +1 -0
  32. torchzero/modules/experimental/__init__.py +1 -6
  33. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  34. torchzero/modules/experimental/cubic_adam.py +4 -0
  35. torchzero/modules/grad_approximation/__init__.py +3 -2
  36. torchzero/modules/least_squares/gn.py +6 -0
  37. torchzero/modules/misc/gradient_accumulation.py +1 -0
  38. torchzero/modules/misc/misc.py +6 -0
  39. torchzero/modules/momentum/averaging.py +6 -0
  40. torchzero/modules/momentum/momentum.py +13 -9
  41. torchzero/modules/ops/__init__.py +0 -1
  42. torchzero/modules/ops/accumulate.py +4 -0
  43. torchzero/modules/ops/higher_level.py +6 -1
  44. torchzero/modules/second_order/inm.py +4 -0
  45. torchzero/modules/second_order/newton.py +11 -3
  46. torchzero/modules/second_order/newton_cg.py +7 -3
  47. torchzero/modules/second_order/nystrom.py +14 -19
  48. torchzero/modules/second_order/rsn.py +37 -6
  49. torchzero/modules/trust_region/trust_region.py +2 -1
  50. torchzero/utils/benchmarks/logistic.py +33 -18
  51. torchzero/utils/optuna_tools.py +1 -1
  52. torchzero/utils/params.py +13 -1
  53. torchzero/utils/tensorlist.py +2 -2
  54. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
  55. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
  56. torchzero/modules/experimental/adanystrom.py +0 -258
  57. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  58. torchzero/modules/experimental/eigen_sr1.py +0 -182
  59. torchzero/modules/experimental/eigengrad.py +0 -207
  60. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  61. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
  62. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,54 +1,58 @@
1
- tests/test_identical.py,sha256=Y48_1f5WrltmO8a_-x-9Yltz2ZeMh8N8q3MGjOCkJhA,11552
1
+ tests/test_identical.py,sha256=8Pw52Q19yeK5maYQEd2HYoOMItN599oRMUKzl-EugfQ,11550
2
2
  tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
3
  tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
4
4
  tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
5
5
  tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
6
6
  tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
7
7
  tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
8
- torchzero/__init__.py,sha256=nit4KxrRoW6hJDGOy0jkphuawY5gAvPqrYY11Yct6fA,133
9
- torchzero/core/__init__.py,sha256=h9Ck7XX2XuJUTojU2IMa_2TprXZHbgo748txa3z7-2o,341
10
- torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
11
- torchzero/core/functional.py,sha256=TSygtyQHDhqf998--hF48yIFr-y3Ycz8arjjR8x1ILU,3156
8
+ torchzero/__init__.py,sha256=SZLJgf_sjHyqtTzz0f70AtHP_V_WloX1KQF8mm34zdg,175
9
+ torchzero/_minimize/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ torchzero/_minimize/methods.py,sha256=1oktoSdWtiA0JEF34yTkY3_nPB5Q5ODHl18C0mcglNw,2445
11
+ torchzero/_minimize/minimize.py,sha256=JJBmREQvhDxyqGM62xharsuebyefxRADkd6Bg_TE-DQ,17236
12
+ torchzero/core/__init__.py,sha256=lufcll5r98gTjVfQSvz6-wfI0qMAgZtLLSByHuHTats,358
13
+ torchzero/core/chain.py,sha256=-6vW-L5pzg2Rwpq3LKIAoqJGPvCkHKjt_B1boGikQmM,1900
14
+ torchzero/core/functional.py,sha256=D125Hso8fHMSKlyhkir3GGJzXxuIitXmVhKn2Y9x-Ck,3272
12
15
  torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
13
- torchzero/core/module.py,sha256=HfbPfxXxgyBf9wQl7Fpw6B6Ux6UYfvPEmITC64ozb_Q,18012
16
+ torchzero/core/module.py,sha256=DKGLwLWm9LkOBYZHW9QBoXo9eBgnYz7nmoCXJ0gl0e0,21210
14
17
  torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
15
18
  torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
16
- torchzero/core/transform.py,sha256=aJRBtvYjKqD-Ic_AkzeSINYDsTaBAErA-kocEl3PHZw,12244
19
+ torchzero/core/transform.py,sha256=WlHoc5cCY1vXQrwMsIG0g3Kle93kBSbrBfxGz5X9_0Q,12251
17
20
  torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
18
21
  torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
19
- torchzero/linalg/eigh.py,sha256=YC8x5NEWWsnc3suCebnTfeb4lVMhy-H8LGOZbGnwd8A,7902
22
+ torchzero/linalg/eigh.py,sha256=l1fX_7hL-DFk8gu20-NuSKDJcRpz58KxUKQHeBhCcHE,9035
20
23
  torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
21
- torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
24
+ torchzero/linalg/linear_operator.py,sha256=MWTY7DS8B8IkR28kVA9nmoM-OU-1eBsP22iYXkDrj9A,16654
22
25
  torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
23
- torchzero/linalg/orthogonalize.py,sha256=Fv6zv1JvS9AVwjiMVed55J8-pEbVZv7vqoEo5g0Zrv0,3270
26
+ torchzero/linalg/orthogonalize.py,sha256=GSvDZA9evTpu3obqCkEocgpDp_91sRexoAwH2q0zTEY,5345
24
27
  torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
28
+ torchzero/linalg/sketch.py,sha256=dKD9t7I7stv089cCvZyPAOZ0D9wzVG1TmV3297w0tk4,1261
25
29
  torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
26
30
  torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
27
31
  torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
28
- torchzero/modules/__init__.py,sha256=dsOalCw-OVkD8rhpQdcODc3Hsd_sQ2_2xVC-J8mlSuk,632
32
+ torchzero/modules/__init__.py,sha256=ZN20E2ES6zDf5DuFbZpuCKFinFc5eGR1h00iYZ_XBGU,652
29
33
  torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
30
34
  torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
31
- torchzero/modules/adaptive/adagrad.py,sha256=hMT-Al-vtD6tzPUpQ79LCNko97D7rJN5ji9JOfBqR3k,12015
35
+ torchzero/modules/adaptive/adagrad.py,sha256=NDwmUZaEk0lWnbgYxN23yTWK5A5dQ9BtoKzRTFSKozY,12131
32
36
  torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
33
- torchzero/modules/adaptive/adam.py,sha256=Okm7Sc9fMArQAZ7Ph4Etq68uL-IXKY4YNqHWpTzPoTY,3767
34
- torchzero/modules/adaptive/adan.py,sha256=965tBUwKy6uDiY2la6fVcGcsvGMs90Zg-ZHPtozJGe4,4110
37
+ torchzero/modules/adaptive/adam.py,sha256=RDHYyIAJdi1Pxny8HOHiCFgvPztNwlJlCtzE_ZE-138,3896
38
+ torchzero/modules/adaptive/adan.py,sha256=tmQHiJ5MNwOGP3fp479goHh0xXlhnzULhHxKcVZOkvM,4219
35
39
  torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
36
40
  torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
37
41
  torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
38
- torchzero/modules/adaptive/ggt.py,sha256=eYCeV3GArdLv9WuWeim0V3CHJYl3FVKtrtsGshkqwWg,6608
39
- torchzero/modules/adaptive/lion.py,sha256=H3aI2qnrMtmkvXcoddzjjxdkoD5cq_QwIkLmd_bVPso,1085
42
+ torchzero/modules/adaptive/ggt.py,sha256=7G0Hh8lWy4o73VYVHcZ1JJyDqeKcXi2Y6Qp3qIHosOY,6512
43
+ torchzero/modules/adaptive/lion.py,sha256=yeKUt3WIITtWx97IQzudgbdai77MCfnL_cu90vRkTmA,1141
40
44
  torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
41
- torchzero/modules/adaptive/mars.py,sha256=w-cK-1tFuR74SY01xS5jsg1b9qs3l8eOptGrUyQ2m80,2261
45
+ torchzero/modules/adaptive/mars.py,sha256=WquKzTnCZcxzslcvSBMFJVz_kjuCuAzlesw1bHnKqOg,2325
42
46
  torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
43
- torchzero/modules/adaptive/msam.py,sha256=nqwjuhBMX2UO-omUIeOcD5ti6PIKfKs-RVCn7ourkKA,6946
47
+ torchzero/modules/adaptive/msam.py,sha256=cHfdNkk3Joy2aENwUZXGf3N0P7zcxYGKuySf699OTfM,7051
44
48
  torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
45
49
  torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
46
50
  torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
47
- torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
48
- torchzero/modules/adaptive/rprop.py,sha256=a4_UkWse5u2JFAEIlxQqDBUwvUfxh1kNs2ZIhtccnWE,11540
51
+ torchzero/modules/adaptive/rmsprop.py,sha256=sb709Smpkm8H3vYOsh7BzWni5hAf3nBQevhagyOt7mo,4655
52
+ torchzero/modules/adaptive/rprop.py,sha256=vw-Rufa-gpHgq1gDarmNQexrFr13lPLq_mj3c3pNB_Q,11593
49
53
  torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
50
- torchzero/modules/adaptive/shampoo.py,sha256=C_Mo7UFQtDxW4McWJjT731FNAp3g9MqF0Hka54Yi3xQ,9847
51
- torchzero/modules/adaptive/soap.py,sha256=hz2N6-jUSWU93RNViIS1c-Ue2uKmQx6BxyYg6mEa2fo,12408
54
+ torchzero/modules/adaptive/shampoo.py,sha256=1WpjroFS37HmDLV51iK4d8vtnJWFrGCsDkoQav0p47E,10048
55
+ torchzero/modules/adaptive/soap.py,sha256=jyS6F2o4bMKzMU8H2dDggFQEqMqw4W1rX78u8p3uaV4,12619
52
56
  torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
53
57
  torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
54
58
  torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
@@ -58,21 +62,20 @@ torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2G
58
62
  torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
59
63
  torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
60
64
  torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
65
+ torchzero/modules/basis/__init__.py,sha256=MeXoykwqqmWt-Gx8YWMycVL7m5N4j7Ob_L0GbcwLOfM,65
66
+ torchzero/modules/basis/ggt_basis.py,sha256=NVddZrv58lm7M2Q2j5_3YYLcBYRdeSB_y03bxExSiJs,7772
67
+ torchzero/modules/basis/soap_basis.py,sha256=pwlxIa9lW9V1NcLPmhm--LVbyq7ALSfkV_4b6ki1hO8,10479
61
68
  torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
62
69
  torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
63
- torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
64
- torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
70
+ torchzero/modules/clipping/ema_clipping.py,sha256=7lFkQWVkchxlZynYXS4JDjhxB8T5tbE0qsP3GXK6mrA,6916
71
+ torchzero/modules/clipping/growth_clipping.py,sha256=VAmUUeIsSGWrGmZiFAngWUBBsxj4d0QAMf36oAMZL8A,6556
65
72
  torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
66
73
  torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
67
- torchzero/modules/experimental/__init__.py,sha256=YbBrWu2vkXHiBcDXmim-Yte4ZxfmQCs_0fCeIArvtnM,942
68
- torchzero/modules/experimental/adanystrom.py,sha256=fUWPxxi1aJhWme_d31dBG0XxEZY1hJr6AEiFHdFDxCQ,8970
69
- torchzero/modules/experimental/common_directions_whiten.py,sha256=R_1fQKlvMD99oFrflJLgxl6ObV8jyPc7-NxAUFQeoYA,4941
70
- torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
71
- torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
74
+ torchzero/modules/experimental/__init__.py,sha256=najUDh01Av6gEeMYRV9X9lWAr4ZrC6ZgJcPtNpon7ZQ,734
75
+ torchzero/modules/experimental/coordinate_momentum.py,sha256=4BMmgooPysYlX7QOaTUjBn6MNfBAMujM5TCm72vSexw,1152
76
+ torchzero/modules/experimental/cubic_adam.py,sha256=97sgbtkqG1ziXOMxlCor-L-UzzqgSumz8shVOgYL4oQ,5303
72
77
  torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
73
78
  torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
74
- torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
75
- torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
76
79
  torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
77
80
  torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
78
81
  torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
@@ -82,15 +85,15 @@ torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNg
82
85
  torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
83
86
  torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
84
87
  torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
85
- torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
86
88
  torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
87
- torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
89
+ torchzero/modules/grad_approximation/__init__.py,sha256=BAFXc73_ORySVDyXiyZxpusXWn7K66KFT9LZEMwVKes,221
88
90
  torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
89
91
  torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
90
92
  torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
91
93
  torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
94
+ torchzero/modules/grad_approximation/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
92
95
  torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
93
- torchzero/modules/least_squares/gn.py,sha256=3RQ_7e35Ql9uVUUPi34nef9eQNeZ09fldi964V61Tgg,7889
96
+ torchzero/modules/least_squares/gn.py,sha256=hufsWNq_UdEPFDFKNGgCiM4R9739Xu8JqYWSwKkdSZ8,8087
94
97
  torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
95
98
  torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
96
99
  torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
@@ -102,21 +105,21 @@ torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4J
102
105
  torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
103
106
  torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
104
107
  torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
105
- torchzero/modules/misc/gradient_accumulation.py,sha256=Xzjt_ulm6Z3mpmtagoUqoefhoeSDVnmX__tVbcI_RQE,2271
108
+ torchzero/modules/misc/gradient_accumulation.py,sha256=1BVqGXwv1YPg7DRJWP0XY6s-vzxrvyXLdruM1Y5KJ5s,2326
106
109
  torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
107
- torchzero/modules/misc/misc.py,sha256=f-3qxBq1KYI3iGYJXzv1cHEJHc0ScEp-vCLCgiaEgJQ,15002
110
+ torchzero/modules/misc/misc.py,sha256=eWVyYSYiQxcS7G7aVM4nFYiF0csE9qcztTaP4id5CbE,15306
108
111
  torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
109
112
  torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
110
113
  torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
111
114
  torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
112
115
  torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
113
- torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
116
+ torchzero/modules/momentum/averaging.py,sha256=OTO_LRNiAhbcKTXrWI-uENqIOH_3DX5_1uYJ3eMVcJY,3202
114
117
  torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
115
- torchzero/modules/momentum/momentum.py,sha256=MPHd4TU1bSlEKLGfueNdmaZ13V5J1suW6agBc3SvrTs,4389
116
- torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
117
- torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
118
+ torchzero/modules/momentum/momentum.py,sha256=aJ8o3gB9HebM9kutpadC5wI0MgMjn-c3J4GF3Z_n0Oc,4484
119
+ torchzero/modules/ops/__init__.py,sha256=p5hwECuODOv6E4H0lETQHweSsUtMlsGE0d8bfTv2Rwc,1225
120
+ torchzero/modules/ops/accumulate.py,sha256=mbJFwykU2fa6IIfsHVXdhmRp7QX1czpCWjw6AYkNn1k,3636
118
121
  torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
119
- torchzero/modules/ops/higher_level.py,sha256=cUh-908S0GWVGekmUN5c_Vx0HP3P2tQoKN3COQM5TaQ,8965
122
+ torchzero/modules/ops/higher_level.py,sha256=f9DFNI9rnxc-rShAJOfsiwvyGsWu8FsJwJf5yg_V4eg,9366
120
123
  torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
121
124
  torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
122
125
  torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
@@ -136,12 +139,12 @@ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6Upy
136
139
  torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
137
140
  torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
138
141
  torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
139
- torchzero/modules/second_order/inm.py,sha256=OddoZHQfSuFnlx_7Zj2qiVcC2A_9yMVn_0Gy1A7hNAg,3420
142
+ torchzero/modules/second_order/inm.py,sha256=_FnaUHKLl46AtI_XYwF52wtOUbAaO5EMUNRJspX5FEM,3574
140
143
  torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
141
- torchzero/modules/second_order/newton.py,sha256=QcLXsglvf4zJEwR4cldsGVZCABQtxb6U5qVmU3spN_A,11061
142
- torchzero/modules/second_order/newton_cg.py,sha256=k8G8CSmeIQZObkWVURFnbF_4g2UvJiwh3xToxn7sFJE,14816
143
- torchzero/modules/second_order/nystrom.py,sha256=WQFfJj0DOfWXyyx36C54m0WqZPIvTTK7n8U7khLhGLg,13359
144
- torchzero/modules/second_order/rsn.py,sha256=9s-JyJNNeDlIFv8YVGn7y8DGPnP93WJEjpUQXehX3uY,9980
144
+ torchzero/modules/second_order/newton.py,sha256=W37_ePdAB1wnlRrNRd2ovNgkbodK1JV8J4SJytVuF_M,11456
145
+ torchzero/modules/second_order/newton_cg.py,sha256=gHmpLRQ2FRr0750gYkFQ7XweJVZmYI6yG9H2vrKvAdA,14925
146
+ torchzero/modules/second_order/nystrom.py,sha256=lGLjtzq2WAWcaT3E6Say82ySZ1yp9I2ASuOqyNTUmiQ,13361
147
+ torchzero/modules/second_order/rsn.py,sha256=13t42cUvY8JQMC4zf4UsqKvpnTXuXZUZJDECCxRYWjg,11286
145
148
  torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
146
149
  torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
147
150
  torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
@@ -155,7 +158,7 @@ torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdy
155
158
  torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
156
159
  torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
157
160
  torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
158
- torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
161
+ torchzero/modules/trust_region/trust_region.py,sha256=ax1pJDr3NPLfojUXRMb-hsxD4MpQL1bPAOwozAVTCJI,12930
159
162
  torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
160
163
  torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
161
164
  torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
@@ -195,15 +198,15 @@ torchzero/utils/derivatives.py,sha256=Sc20EH2v2czjH9Z8UChvq0EaYtvOEJKEYOk3fVb0Z6
195
198
  torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
196
199
  torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
197
200
  torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
198
- torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
199
- torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
201
+ torchzero/utils/optuna_tools.py,sha256=t64nwyuIVP7xgeGVvIGMFBij2j5clhjY4BHtGEnyPVI,1323
202
+ torchzero/utils/params.py,sha256=-amJs518rpI0zzYavTlWrl60JNrgsk1xxdGvIrSw1ZI,6406
200
203
  torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
201
- torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
204
+ torchzero/utils/tensorlist.py,sha256=wpzBJvIAmw9VXsg1UF8gZtq-eh7GlvdM6WL_7NyPYlY,56363
202
205
  torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
203
206
  torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
204
207
  torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
205
- torchzero/utils/benchmarks/logistic.py,sha256=RHsjHEWkPqaag0kt3wfmdddh4DhftcyW9r70tj9OGp4,4382
206
- torchzero-0.4.1.dist-info/METADATA,sha256=hB0rFqXnaRbwVkFRwTwjXpKnIFLi8MBvLXbgXTuUGWk,564
207
- torchzero-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
208
- torchzero-0.4.1.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
209
- torchzero-0.4.1.dist-info/RECORD,,
208
+ torchzero/utils/benchmarks/logistic.py,sha256=1c9kB6tDaKsSNlQn44_Lso2_g-85fQK45RvwLZOcJOo,4587
209
+ torchzero-0.4.3.dist-info/METADATA,sha256=39RK0MpaBQIm0GpIK2YRwoeY5zegEBnJHCZIY4ExQ5k,564
210
+ torchzero-0.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
211
+ torchzero-0.4.3.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
212
+ torchzero-0.4.3.dist-info/RECORD,,
@@ -1,258 +0,0 @@
1
- # pylint: disable = non-ascii-name
2
- import torch
3
-
4
- from ...core import Chainable, TensorTransform
5
- from ...linalg import (
6
- OrthogonalizeMethod,
7
- orthogonalize,
8
- regularize_eigh,
9
- torch_linalg,
10
- )
11
- from ...linalg.linear_operator import Eigendecomposition
12
- from ..adaptive.lre_optimizers import LREOptimizerBase
13
- from .eigengrad import _eigengrad_update_state_, eigengrad_apply
14
-
15
-
16
- def weighted_eigen_plus_rank1_mm(
17
- # A1 = Q1 @ diag(L1) @ Q1.T
18
- L1: torch.Tensor,
19
- Q1: torch.Tensor,
20
-
21
- # K2 = v2 @ v2.T
22
- v2: torch.Tensor,
23
-
24
- # second matrix
25
- B: torch.Tensor,
26
-
27
- # weights
28
- w1: float,
29
- w2: float,
30
-
31
- ) -> torch.Tensor:
32
- """
33
- Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
34
-
35
- Returns ``(n, k)``
36
-
37
- Args:
38
- L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
39
- Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
40
- v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
41
- B (torch.Tensor): shape ``(n, k)``.
42
- w1 (float): weight for A1.
43
- w2 (float): weight for A2.
44
-
45
- """
46
- # sketch A1
47
- QTB = Q1.T @ B # (rank, k)
48
- LQTB = L1.unsqueeze(1) * QTB # (rank, k)
49
- sketch1 = Q1 @ LQTB # (n, k)
50
-
51
- # skecth A2
52
- vB = v2 @ B
53
- sketch2 = v2.outer(vB)
54
-
55
- return w1 * sketch1 + w2 * sketch2
56
-
57
-
58
- def adanystrom_update(
59
- L1: torch.Tensor,
60
- Q1: torch.Tensor,
61
- v2: torch.Tensor,
62
- w1: float,
63
- w2: float,
64
- oversampling_p: int,
65
- rank: int,
66
- eig_tol: float,
67
- damping: float,
68
- rdamping: float,
69
- orthogonalize_method: OrthogonalizeMethod,
70
-
71
- ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
72
- """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
73
- where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
74
-
75
- returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
76
-
77
- Args:
78
- L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
79
- Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
80
- v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
81
- w1 (float): weight for A1.
82
- w2 (float): weight for A2.
83
- """
84
- n = Q1.shape[0]
85
- device = Q1.device
86
- dtype = Q1.dtype
87
- l = rank + oversampling_p
88
-
89
- # gaussian test matrix
90
- Omega = torch.randn(n, l, device=device, dtype=dtype)
91
-
92
- # sketch
93
- AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
94
- Q = orthogonalize(AOmega, orthogonalize_method)
95
-
96
- AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
97
- QTAQ = Q.T @ AQ
98
-
99
- W = (QTAQ + QTAQ.T) / 2.0
100
-
101
- # compute new L and Q
102
- try:
103
- L_prime, S = torch_linalg.eigh(W, retry_float64=True)
104
- except torch.linalg.LinAlgError:
105
- return L1, Q1
106
-
107
- L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
108
-
109
- if L_prime is None or S is None:
110
- return L1, Q1
111
-
112
- return L_prime, Q @ S
113
-
114
-
115
- # def adanystrom_update2(
116
- # L1: torch.Tensor,
117
- # Q1: torch.Tensor,
118
- # v2: torch.Tensor,
119
- # w1: float,
120
- # w2: float,
121
- # rank: int,
122
- # ):
123
- # def A_mm(X):
124
- # return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
125
-
126
- # return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
127
-
128
- class AdaNystrom(TensorTransform):
129
- """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
130
-
131
- Args:
132
- rank (_type_): rank of Nyström approximation.
133
- w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
134
- w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
135
- oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
136
- eig_tol (float, optional):
137
- removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
138
- damping (float, optional):
139
- added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
140
- rdamping (float, optional):
141
- added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
142
- mm_tol (float, optional):
143
- removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
144
- mm_truncate (int | None, optional):
145
- uses top k eigenvalues to compute the update. Defaults to None.
146
- mm_damping (float, optional):
147
- added to eigenvalues when computing the update. Defaults to 1e-4.
148
- mm_rdamping (float, optional):
149
- added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
150
- id_reg (float, optional):
151
- multiplier to identity matrix added to preconditioner before computing update
152
- If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
153
- This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
154
- concat_params (bool, optional):
155
- whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
156
- update_freq (int, optional): update frequency. Defaults to 1.
157
- inner (Chainable | None, optional): inner modules. Defaults to None.
158
- """
159
- def __init__(
160
- self,
161
- rank:int = 100,
162
- beta=0.95,
163
- oversampling: int = 10,
164
- eig_tol: float | None = 1e-32,
165
- damping: float = 0,
166
- rdamping: float = 0,
167
- mm_tol: float = 0,
168
- mm_truncate: int | None = None,
169
- mm_damping: float = 0,
170
- mm_rdamping: float = 0,
171
- id_reg: float | None = None,
172
- orthogonalize_method: OrthogonalizeMethod = 'qr',
173
- eigenbasis_optimizer: LREOptimizerBase | None = None,
174
- orthogonalize_interval: int | None = 100,
175
-
176
- concat_params: bool = True,
177
- update_freq: int = 1,
178
- inner: Chainable | None = None,
179
- ):
180
- defaults = locals().copy()
181
- for k in ["self", "concat_params", "inner", "update_freq"]:
182
- del defaults[k]
183
-
184
- super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
185
-
186
- def single_tensor_update(self, tensor, param, grad, loss, state, setting):
187
- state["step"] = state.get("step", 0) + 1
188
- rank = setting["rank"]
189
- device = tensor.device
190
- dtype = tensor.dtype
191
- beta = setting["beta"]
192
-
193
- try:
194
- if "L" not in state:
195
- # use just tensor and zero L and Q with zero weight
196
-
197
- L, Q = adanystrom_update(
198
- L1=torch.zeros(rank, device=device, dtype=dtype),
199
- Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
200
- v2=tensor.ravel(),
201
- w1=0,
202
- w2=1-beta,
203
- rank=rank,
204
- oversampling_p=setting["oversampling"],
205
- eig_tol=setting["eig_tol"],
206
- damping=setting["damping"],
207
- rdamping=setting["rdamping"],
208
- orthogonalize_method=setting["orthogonalize_method"],
209
- )
210
-
211
- state["L"] = state["L_reg"] = L
212
- state["Q"] = state["Q_reg"] = Q
213
-
214
- else:
215
- L = state["L"]
216
- Q = state["Q"]
217
-
218
- w1 = beta
219
- w2 = 1 - w1
220
-
221
- # compute new factors (this function truncates them)
222
- L_new, Q_new = adanystrom_update(
223
- L1=L,
224
- Q1=Q,
225
- v2=tensor.ravel(),
226
- w1=w1,
227
- w2=w2,
228
- rank=rank,
229
- oversampling_p=setting["oversampling"],
230
- eig_tol=setting["eig_tol"],
231
- damping=setting["damping"],
232
- rdamping=setting["rdamping"],
233
- orthogonalize_method=setting["orthogonalize_method"],
234
- )
235
-
236
- _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
237
-
238
- except torch.linalg.LinAlgError:
239
- pass
240
-
241
- def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
242
- if "L_reg" not in state:
243
- return tensor.clip(-0.1, 0.1)
244
-
245
- if "eigenbasis_state" not in state:
246
- state["eigenbasis_state"] = {}
247
-
248
- return eigengrad_apply(
249
- tensor=tensor,
250
- L_reg = state["L_reg"],
251
- Q_reg = state["Q_reg"],
252
- beta = setting["beta"],
253
- step = state["step"],
254
- debias = True,
255
- id_reg = setting["id_reg"],
256
- eigenbasis_optimizer = setting["eigenbasis_optimizer"],
257
- eigenbasis_state = state["eigenbasis_state"]
258
- )
@@ -1,142 +0,0 @@
1
- from collections import deque
2
- from typing import Literal
3
-
4
- import torch
5
-
6
- from torchzero.core import Chainable, TensorTransform
7
- from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
8
- from torchzero.utils import TensorList, vec_to_tensors_
9
-
10
-
11
- def update_subspace_preconditioner_(
12
- grad: torch.Tensor, # store grads and basis as vectors for matmul
13
- basis: torch.Tensor, # ndim, k
14
- accumulator_: torch.Tensor, # k, k
15
- beta: float | None,
16
- ):
17
- projected = basis.T @ grad # k
18
- outer = torch.outer(projected, projected)
19
-
20
- if beta is None: accumulator_.add_(outer)
21
- else: accumulator_.lerp_(outer, 1-beta)
22
-
23
- # yeah so I can also run subspace opts in this basis
24
def apply_subspace_preconditioner(
    tensor: torch.Tensor,
    basis: torch.Tensor, # ndim, k
    accumulator: torch.Tensor,
    tol: float,
    truncate: int | None,
    damping: float,
    rdamping: float,
):
    """Whiten ``tensor`` inside the subspace spanned by ``basis``.

    Eigendecomposes the ``k x k`` covariance ``accumulator``, regularizes the
    spectrum, and applies the inverse square root to the projected tensor
    before mapping it back to the full space. Returns a clipped fallback when
    regularization discards the entire spectrum.
    """
    eigvals, eigvecs = torch_linalg.eigh(accumulator, retry_float64=True)
    eigvals, eigvecs = regularize_eigh(
        L=eigvals, Q=eigvecs, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping
    )

    # whole spectrum filtered out -> no usable curvature information
    if eigvals is None or eigvecs is None:
        return tensor.clip(-0.1, 0.1)

    # Q diag(L^{-1/2}) Q^H — inverse square root of the accumulator
    inv_sqrt = (eigvecs * eigvals.rsqrt().unsqueeze(-2)) @ eigvecs.mH

    coords = basis.T @ tensor        # k
    whitened = inv_sqrt @ coords     # k
    return basis @ whitened          # d
44
-
45
-
46
class CommonDirectionsWhiten(TensorTransform):
    """Whitens the update in a subspace spanned by a history of gradients
    (or gradient differences).

    A ``k``-dimensional basis is maintained as an EMA of the (orthonormalized)
    stacked gradient history; a ``k x k`` covariance of projected gradients is
    accumulated in that basis and its inverse square root preconditions the
    projected update.

    Args:
        k: subspace dimension (history length).
        beta: EMA coefficient for the projected-gradient covariance; ``None``
            accumulates a plain sum instead (no bias correction is applied then).
        basis_beta: EMA coefficient controlling how much the basis may change
            per step.
        tol: eigenvalue tolerance passed to ``regularize_eigh``.
        truncate: keep only this many leading eigenpairs (``None`` keeps all).
        damping: absolute eigenvalue damping.
        rdamping: relative eigenvalue damping.
        basis_type: ``"gradients"`` stacks raw gradients; ``"differences"``
            replaces all but the last column with consecutive differences
            (the last column stays the current gradient).
        orthogonalize_method: how to orthonormalize the stacked history, or
            ``None`` to merely standardize it.
        concat_params: treat all parameters as one concatenated vector.
        inner: optional inner modules.
    """

    def __init__(
        self,
        k: int = 100,
        beta: float | None = 0.95,
        basis_beta=0.95,
        tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        basis_type: Literal["gradients", "differences"] = "differences",
        orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',

        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        # capture every hyperparameter as a per-parameter setting,
        # excluding the arguments that are not settings
        defaults = locals().copy()
        for key in ["self", "inner", "concat_params"]:
            del defaults[key]

        super().__init__(defaults, concat_params=concat_params, inner=inner)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        g = tensor.ravel()
        k = setting['k']
        beta = setting['beta']
        basis_beta = setting['basis_beta']
        step = state.get("step", 0)
        state["step"] = step + 1

        # initialize history, accumulator and basis on the first step
        if 'history' not in state:
            state['history'] = deque(maxlen=k)
            state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
            state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)

        history: deque = state['history']
        accumulator = state['accumulator']
        basis = state['basis']
        history.append(g)

        # stack history to new basis term, if history isn't full, fill with random vecs
        if len(history) < k:
            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
            history_basis = torch.stack(tuple(history), -1)
            basis_t[:, -len(history):] = history_basis

        else:
            basis_t = torch.stack(tuple(history), -1)

        # in this case basis uses differences in gradients except last entry is the gradient
        # NOTE(review): while history isn't full, this also differences the
        # random padding columns against each other — confirm intentional.
        if setting["basis_type"] == "differences":
            basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]

        # normalize or orthonormalize new basis term
        if setting["orthogonalize_method"] is not None:
            basis_t = orthogonalize(basis_t, method = setting["orthogonalize_method"])
        else:
            basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)

        # lerp basis; state keeps the raw EMA — the bias-corrected copy is
        # only used for the accumulator update below
        basis.lerp_(basis_t, 1-basis_beta)
        basis = basis / (1 - basis_beta ** (step+1)) # correct bias on basis EMA
        update_subspace_preconditioner_(g, basis, accumulator, beta)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        g = tensor.ravel()

        # NOTE(review): unlike the update step, the basis used here is the raw
        # (not bias-corrected) EMA — confirm intentional.
        basis = state['basis']
        accumulator = state['accumulator']
        step = state["step"]

        beta = setting["beta"]
        if beta is not None:
            # correct bias on accumulator EMA. Skipped when beta is None: the
            # accumulator is a plain sum in that case (see
            # update_subspace_preconditioner_), and ``None ** (step+1)``
            # would raise a TypeError.
            accumulator = accumulator / (1 - beta ** (step+1))

        try:
            preconditioned = apply_subspace_preconditioner(
                g,
                basis,
                accumulator,
                tol=setting["tol"],
                truncate=setting["truncate"],
                damping=setting["damping"],
                rdamping=setting["rdamping"],
            )
        except torch.linalg.LinAlgError:
            # factorization failed — fall back to a clipped update
            preconditioned = g.clip(-0.1, 0.1)

        return preconditioned.view_as(tensor)