torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,209 @@
1
+ tests/test_identical.py,sha256=Y48_1f5WrltmO8a_-x-9Yltz2ZeMh8N8q3MGjOCkJhA,11552
2
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
+ tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
4
+ tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
5
+ tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
6
+ tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
7
+ tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
8
+ torchzero/__init__.py,sha256=nit4KxrRoW6hJDGOy0jkphuawY5gAvPqrYY11Yct6fA,133
9
+ torchzero/core/__init__.py,sha256=h9Ck7XX2XuJUTojU2IMa_2TprXZHbgo748txa3z7-2o,341
10
+ torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
11
+ torchzero/core/functional.py,sha256=TSygtyQHDhqf998--hF48yIFr-y3Ycz8arjjR8x1ILU,3156
12
+ torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
13
+ torchzero/core/module.py,sha256=HfbPfxXxgyBf9wQl7Fpw6B6Ux6UYfvPEmITC64ozb_Q,18012
14
+ torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
15
+ torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
16
+ torchzero/core/transform.py,sha256=aJRBtvYjKqD-Ic_AkzeSINYDsTaBAErA-kocEl3PHZw,12244
17
+ torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
18
+ torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
19
+ torchzero/linalg/eigh.py,sha256=YC8x5NEWWsnc3suCebnTfeb4lVMhy-H8LGOZbGnwd8A,7902
20
+ torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
21
+ torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
22
+ torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
23
+ torchzero/linalg/orthogonalize.py,sha256=Fv6zv1JvS9AVwjiMVed55J8-pEbVZv7vqoEo5g0Zrv0,3270
24
+ torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
25
+ torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
26
+ torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
27
+ torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
28
+ torchzero/modules/__init__.py,sha256=dsOalCw-OVkD8rhpQdcODc3Hsd_sQ2_2xVC-J8mlSuk,632
29
+ torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
30
+ torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
31
+ torchzero/modules/adaptive/adagrad.py,sha256=hMT-Al-vtD6tzPUpQ79LCNko97D7rJN5ji9JOfBqR3k,12015
32
+ torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
33
+ torchzero/modules/adaptive/adam.py,sha256=Okm7Sc9fMArQAZ7Ph4Etq68uL-IXKY4YNqHWpTzPoTY,3767
34
+ torchzero/modules/adaptive/adan.py,sha256=965tBUwKy6uDiY2la6fVcGcsvGMs90Zg-ZHPtozJGe4,4110
35
+ torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
36
+ torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
37
+ torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
38
+ torchzero/modules/adaptive/ggt.py,sha256=eYCeV3GArdLv9WuWeim0V3CHJYl3FVKtrtsGshkqwWg,6608
39
+ torchzero/modules/adaptive/lion.py,sha256=H3aI2qnrMtmkvXcoddzjjxdkoD5cq_QwIkLmd_bVPso,1085
40
+ torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
41
+ torchzero/modules/adaptive/mars.py,sha256=w-cK-1tFuR74SY01xS5jsg1b9qs3l8eOptGrUyQ2m80,2261
42
+ torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
43
+ torchzero/modules/adaptive/msam.py,sha256=nqwjuhBMX2UO-omUIeOcD5ti6PIKfKs-RVCn7ourkKA,6946
44
+ torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
45
+ torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
46
+ torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
47
+ torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
48
+ torchzero/modules/adaptive/rprop.py,sha256=a4_UkWse5u2JFAEIlxQqDBUwvUfxh1kNs2ZIhtccnWE,11540
49
+ torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
50
+ torchzero/modules/adaptive/shampoo.py,sha256=C_Mo7UFQtDxW4McWJjT731FNAp3g9MqF0Hka54Yi3xQ,9847
51
+ torchzero/modules/adaptive/soap.py,sha256=hz2N6-jUSWU93RNViIS1c-Ue2uKmQx6BxyYg6mEa2fo,12408
52
+ torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
53
+ torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
54
+ torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
55
+ torchzero/modules/adaptive/psgd/psgd.py,sha256=CVDJI3fcdPCpLCWFY_pvd_COaZMB3nrYOYDdqQcEaOY,73340
56
+ torchzero/modules/adaptive/psgd/psgd_dense_newton.py,sha256=62R9rkWRynrvmTzb12zi5oQoWWdAWz4i1EC4FXa4zv4,7094
57
+ torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2GG0nfjx6NWzUHYKY_FAvRpg,8381
58
+ torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
59
+ torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
60
+ torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
61
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
62
+ torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
63
+ torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
64
+ torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
65
+ torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
66
+ torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
67
+ torchzero/modules/experimental/__init__.py,sha256=YbBrWu2vkXHiBcDXmim-Yte4ZxfmQCs_0fCeIArvtnM,942
68
+ torchzero/modules/experimental/adanystrom.py,sha256=fUWPxxi1aJhWme_d31dBG0XxEZY1hJr6AEiFHdFDxCQ,8970
69
+ torchzero/modules/experimental/common_directions_whiten.py,sha256=R_1fQKlvMD99oFrflJLgxl6ObV8jyPc7-NxAUFQeoYA,4941
70
+ torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
71
+ torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
72
+ torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
73
+ torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
74
+ torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
75
+ torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
76
+ torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
77
+ torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
78
+ torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
79
+ torchzero/modules/experimental/l_infinity.py,sha256=zu3aRLmZkU4LxfyToqjU8re0BLMUd6gl16_9AffIfcw,4752
80
+ torchzero/modules/experimental/matrix_nag.py,sha256=fjj07uZbmYvy4AgrtqvX_WLwvK7HBpKx09Q3BK3L0jo,4451
81
+ torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNgEnxOr04PVBmIiBxmiQ,4073
82
+ torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
83
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
84
+ torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
85
+ torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
86
+ torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
87
+ torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
88
+ torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
89
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
90
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
91
+ torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
92
+ torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
93
+ torchzero/modules/least_squares/gn.py,sha256=3RQ_7e35Ql9uVUUPi34nef9eQNeZ09fldi964V61Tgg,7889
94
+ torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
95
+ torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
96
+ torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
97
+ torchzero/modules/line_search/backtracking.py,sha256=xomhfRmOquGn_liGLS-5PKP89YpAtzFbM3vPTjz0T5A,9051
98
+ torchzero/modules/line_search/interpolation.py,sha256=tHXlZD1MgfYaymhXY75k9CocltwSYDYZg6ENDCEUiss,4942
99
+ torchzero/modules/line_search/line_search.py,sha256=Qcil-WbXnOpdi_DqzvRzXnInedUYNXDglsnYqfGYHdk,12925
100
+ torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
101
+ torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4JMx2aVHqFn_8qpX6g,14979
102
+ torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
103
+ torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
104
+ torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
105
+ torchzero/modules/misc/gradient_accumulation.py,sha256=Xzjt_ulm6Z3mpmtagoUqoefhoeSDVnmX__tVbcI_RQE,2271
106
+ torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
107
+ torchzero/modules/misc/misc.py,sha256=f-3qxBq1KYI3iGYJXzv1cHEJHc0ScEp-vCLCgiaEgJQ,15002
108
+ torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
109
+ torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
110
+ torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
111
+ torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
112
+ torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
113
+ torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
114
+ torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
115
+ torchzero/modules/momentum/momentum.py,sha256=MPHd4TU1bSlEKLGfueNdmaZ13V5J1suW6agBc3SvrTs,4389
116
+ torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
117
+ torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
118
+ torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
119
+ torchzero/modules/ops/higher_level.py,sha256=cUh-908S0GWVGekmUN5c_Vx0HP3P2tQoKN3COQM5TaQ,8965
120
+ torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
121
+ torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
122
+ torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
123
+ torchzero/modules/ops/utility.py,sha256=UkR3BfN_NBsZe78oERnknf6lmgeGNiEbtfUDNgU0YoQ,4423
124
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
125
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
126
+ torchzero/modules/projections/galore.py,sha256=70k30-2RJ3ncTNZpjxBhyYq1yJVFGfS2YUvU2trIK4o,258
127
+ torchzero/modules/projections/projection.py,sha256=FK1gURrOdSaKLVsTHgSeokoTfCGTLYGY9WB4slpclr4,14485
128
+ torchzero/modules/quasi_newton/__init__.py,sha256=o0N_XvzY5qGOsRYXU2bNdF8_zRb5ZZxaL6inTBl2Sk4,536
129
+ torchzero/modules/quasi_newton/damping.py,sha256=bz2GFjasb-2B6bo8dn3-UJ3ddkiUbiJtaQBLqsixv_E,2806
130
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=q6qCYQzWZxqg5S55mmI6Kf0uRoIr_dtSWrFd-2jpe6o,6929
131
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=ftuY65N7_zJtByQ_JM_sRgxKqW7vNktuMb7qFqDGF8M,11203
132
+ torchzero/modules/quasi_newton/lsr1.py,sha256=5gZKoFBVCLg3XeXOQTALAkk3XvU2wMdjDi5J_B7iJLs,8515
133
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=vjgYWInavb81twwQQWPfK7gj1aT4ljJNEMSnNsDOM9c,46015
134
+ torchzero/modules/quasi_newton/sg2.py,sha256=-d6dVqlt6Xf03tN1a26Jx9YK2kqOS_snuMuIUT3iYqQ,4164
135
+ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
136
+ torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
137
+ torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
138
+ torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
139
+ torchzero/modules/second_order/inm.py,sha256=OddoZHQfSuFnlx_7Zj2qiVcC2A_9yMVn_0Gy1A7hNAg,3420
140
+ torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
141
+ torchzero/modules/second_order/newton.py,sha256=QcLXsglvf4zJEwR4cldsGVZCABQtxb6U5qVmU3spN_A,11061
142
+ torchzero/modules/second_order/newton_cg.py,sha256=k8G8CSmeIQZObkWVURFnbF_4g2UvJiwh3xToxn7sFJE,14816
143
+ torchzero/modules/second_order/nystrom.py,sha256=WQFfJj0DOfWXyyx36C54m0WqZPIvTTK7n8U7khLhGLg,13359
144
+ torchzero/modules/second_order/rsn.py,sha256=9s-JyJNNeDlIFv8YVGn7y8DGPnP93WJEjpUQXehX3uY,9980
145
+ torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
146
+ torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
147
+ torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
148
+ torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
149
+ torchzero/modules/step_size/adaptive.py,sha256=TYLESwhcs6QCR9AvQWtMuU-XJRDmNEmSCnPdBFwijiY,14679
150
+ torchzero/modules/step_size/lr.py,sha256=Pi8B5hhJWq_NdVMdR4reZhqu1jjC2S5TWTzkMKAEMcM,5931
151
+ torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
152
+ torchzero/modules/termination/termination.py,sha256=lJXLmtA84JoK_QHhjBxfW9lkxGIxqg7cE7d755MeduE,6905
153
+ torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
154
+ torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdyMvUzdVqOMPyKnJ7FC_d5cR5V8,6711
155
+ torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
156
+ torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
157
+ torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
158
+ torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
159
+ torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
160
+ torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
161
+ torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
162
+ torchzero/modules/weight_decay/reinit.py,sha256=Ngb3AUfJU-aE-lUSSDmf6rlx2zJUgJ6JrQ3U1DpuEh8,3311
163
+ torchzero/modules/weight_decay/weight_decay.py,sha256=dC0JidMlUU4WzXRTSwQT0Gp45UdKvZLu6JnXveJTR4c,5362
164
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
165
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=T-DnI_pbNpt1BN9qvW6KOVGly1RixsdSpIC8FtmmDlY,4701
166
+ torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
167
+ torchzero/modules/zeroth_order/cd.py,sha256=6mHQC_pInVQ6HmhW5yAKwkPjBVwhUg8uX-J3GEIb9wQ,5018
168
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
169
+ torchzero/optim/mbs.py,sha256=mUbl9yl7cxNzmp7J6CXkkjcUd3KBF5zPdDjuiDgJxVU,10970
170
+ torchzero/optim/root.py,sha256=MnXytsgTjDKhYw3UKux1O-g4vuHsFQu2Emsq5zphu_8,2308
171
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
172
+ torchzero/optim/utility/split.py,sha256=uNnKA40OiXUlr-vlHuU_rLEUfXQXgvr6Cd9yGzjWJiA,1702
173
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
174
+ torchzero/optim/wrappers/directsearch.py,sha256=7f1Zy9ZxftgIFxJ9_1snyMjp8ypB4ULyddOqgxdwkEg,9806
175
+ torchzero/optim/wrappers/fcmaes.py,sha256=Bobx87TWinICd3SyNaKhoplpNbrLdD5vR1J1TZw69Cs,3437
176
+ torchzero/optim/wrappers/mads.py,sha256=Gm1msCsVNJDq-WS913peZ0TzXU_7wgP_sPmMfEK0VIY,2326
177
+ torchzero/optim/wrappers/moors.py,sha256=AjEeovPEvxEGJwbhOaOCrCIxPgwOHilkj-vWb-88Y_0,2073
178
+ torchzero/optim/wrappers/nevergrad.py,sha256=wFjIPOSffdRkxdtbwBMBw1_ubP9kS2z4nHV6YtqUrvw,4442
179
+ torchzero/optim/wrappers/nlopt.py,sha256=FzU_3VuNN0_ngT_FlQs7ZOL_RZcfeFtJNmklqvyzTkw,8694
180
+ torchzero/optim/wrappers/optuna.py,sha256=Gvs8tGGIV-mJZ1nMrLN3Ju3am-lbabLd_-mVRO6JlkY,2252
181
+ torchzero/optim/wrappers/pybobyqa.py,sha256=1eH3HW2HVyNneZ5vB9ch1hb0OdoYrnk4pJmQVRIOAw4,6309
182
+ torchzero/optim/wrappers/wrapper.py,sha256=eEZnqKRc2cSFVVXNyOGGReMX3MvY0eoHSRjFsQ8pCSw,4380
183
+ torchzero/optim/wrappers/scipy/__init__.py,sha256=lAciW9WTZ8JZ9ZFunyi78cOa6URhTL-rIjk1AdqbrxM,262
184
+ torchzero/optim/wrappers/scipy/basin_hopping.py,sha256=11FuVLYDJxbyxW-44UO-JRef7MErGFfM_PqqqADA8RQ,4435
185
+ torchzero/optim/wrappers/scipy/brute.py,sha256=MD5fWcUo0-COrq5_OKM8X5SSMZaN--4ThW8mdISLX00,1213
186
+ torchzero/optim/wrappers/scipy/differential_evolution.py,sha256=GBJSh0K5ueZX3fY9vxTntQnuIoLDjOQyOj6TMCmTGBg,2669
187
+ torchzero/optim/wrappers/scipy/direct.py,sha256=lTkH8aRY7sWkhrtQ8Z-m8fSdhTpsLxOIQSJprSqThFA,1821
188
+ torchzero/optim/wrappers/scipy/dual_annealing.py,sha256=mIdXy8F3YEfrQCXNxps8T33yRDuqJxaaIjzB0wlKo90,4117
189
+ torchzero/optim/wrappers/scipy/experimental.py,sha256=-_82LtrDVg0qRg6E1coXugsCGGcVNXgxv6m5Anaaw24,4833
190
+ torchzero/optim/wrappers/scipy/minimize.py,sha256=6O2Js2ecGZkNE3mqXJR9Rqp5xlh3KIVYBMF82k5hrws,6491
191
+ torchzero/optim/wrappers/scipy/sgho.py,sha256=mKG9uG-zaEzB6mbiZ25FnXdRcfvGYYOHapuHniyJ8r8,4054
192
+ torchzero/utils/__init__.py,sha256=tpy9ti5Ub5d1zQPFODZ4PjmFKNzoZTd-NByd_snlYtk,761
193
+ torchzero/utils/compile.py,sha256=dY9ioWQvJt71HQN_z4kXX4KgtJ0xOW0MKNlZpHD6118,5130
194
+ torchzero/utils/derivatives.py,sha256=Sc20EH2v2czjH9Z8UChvq0EaYtvOEJKEYOk3fVb0Z6M,18085
195
+ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
196
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
197
+ torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
198
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
199
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
200
+ torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
201
+ torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
202
+ torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
203
+ torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
204
+ torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
205
+ torchzero/utils/benchmarks/logistic.py,sha256=RHsjHEWkPqaag0kt3wfmdddh4DhftcyW9r70tj9OGp4,4382
206
+ torchzero-0.4.1.dist-info/METADATA,sha256=hB0rFqXnaRbwVkFRwTwjXpKnIFLi8MBvLXbgXTuUGWk,564
207
+ torchzero-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
208
+ torchzero-0.4.1.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
209
+ torchzero-0.4.1.dist-info/RECORD,,
@@ -1,241 +0,0 @@
1
- from collections import deque
2
- from typing import Literal, Any
3
- import warnings
4
-
5
- import torch
6
- from ...core import Chainable, TensorTransform
7
- from ...linalg import torch_linalg
8
-
9
def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, tol):
    """Compute the low-rank factors used by limited-memory full-matrix Adagrad.

    The history is stacked into a matrix ``M`` with one column per stored vector,
    ``MᵀM`` is eigendecomposed, and the eigenvectors are mapped back through ``M``
    and scaled by the inverse square root of the eigenvalues.

    Args:
        history: deque of 1-D tensors (stacked along ``dim=1``) or an already
            assembled matrix.
        damping: value added to the diagonal of ``MᵀM`` before the
            eigendecomposition; skipped when equal to 0.
        rdamping: when non-zero, ``rdamping * L_max`` is added to the returned
            eigenvalues, where ``L_max`` is the largest retained eigenvalue.
            This happens after ``U`` has been computed, so it only affects ``L``.
        truncate: number of largest eigenvalues to keep. ``None`` or a
            non-positive value disables truncation.
        tol: eigenvalues that are not larger than ``tol * L_max`` are removed.

    Returns:
        U ``(ndim, rank)``, L ``(rank, )``, or ``(None, None)`` when the
        eigendecomposition fails or no positive eigenvalue is available.
    """
    if isinstance(history, torch.Tensor):
        M = history
    else:
        M = torch.stack(tuple(history), dim=1)

    MTM = M.T @ M
    if damping != 0:
        MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

    # only the decomposition itself is expected to raise LinAlgError
    try:
        L, Q = torch_linalg.eigh(MTM, retry_float64=True)
    except torch.linalg.LinAlgError:
        return None, None

    # truncate to top n largest eigenvalues
    if truncate is not None and truncate > 0:
        # L is ordered in ascending order
        L = L[-truncate:]
        Q = Q[:, -truncate:]

    L_max = L.amax()

    # With a non-positive (or NaN) largest eigenvalue nothing passes the relative
    # filter below and ``L.rsqrt()`` would produce inf/NaN factors. Report failure
    # the same way as a failed decomposition so callers keep their previous factors.
    if not bool(L_max > 0):
        return None, None

    # remove small eigenvalues relative to largest
    indices = L > tol * L_max
    if indices.any():
        L = L[indices]
        Q = Q[:, indices]

    U = (M @ Q) * L.rsqrt()

    if rdamping != 0:
        L.add_(rdamping * L_max)

    return U, L
45
-
46
- def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor, exp_avg_proj: torch.Tensor | None, beta:float):
47
- z = U.T @ g
48
-
49
- if beta != 0:
50
- if exp_avg_proj is None: exp_avg_proj = torch.zeros_like(z)
51
- exp_avg_proj.lerp_(z, weight=1-beta)
52
- z = exp_avg_proj
53
-
54
- return (U * L.rsqrt()) @ z, exp_avg_proj
55
-
56
- def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
57
- if value is None: return
58
- if (key not in state_) or (beta is None): state_[key] = value
59
- else:
60
- if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
61
- else: state_[key].lerp_(value, 1-beta)
62
-
63
class LMAdagrad(TensorTransform):
    """
    Limited-memory full matrix Adagrad.

    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't need V.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 100.
        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to the largest eigenvalue. Defaults to 0.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        order (int, optional):
            order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
        U_tfm (Chainable | None, optional):
            modules applied to U before preconditioning. Requires ``concat_params=True``. Defaults to None.
        L_tfm (Chainable | None, optional):
            modules applied to L before preconditioning. Requires ``concat_params=True``. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – pp. 102-110.
    """

    def __init__(
        self,
        history_size: int = 100,
        beta: float = 0.0,
        update_freq: int = 1,
        damping: float = 1e-4,
        rdamping: float = 0,
        truncate: int | None = None,
        tol: float = 1e-7,
        order: int = 1,
        U_beta: float | None = None,
        L_beta: float | None = None,
        concat_params: bool = True,

        inner: Chainable | None = None,
        U_tfm: Chainable | None = None,
        L_tfm: Chainable | None = None,
    ):
        # every constructor argument except the ones handled separately becomes a per-parameter setting
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults['concat_params'], defaults["U_tfm"], defaults["L_tfm"]

        super().__init__(defaults, concat_params=concat_params, inner=inner)

        self.set_child("U", U_tfm)
        self.set_child("L", L_tfm)


    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        """Append to the history and, every ``update_freq`` steps, recompute the ``U`` and ``L`` factors."""
        order = setting['order']
        history_size = setting['history_size']
        update_freq = setting['update_freq']
        U_beta = setting['U_beta']
        L_beta = setting['L_beta']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        if order == 1:
            t = tensor.clone().view(-1)
            history.append(t)
        else:

            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
            # scaled by parameter differences
            cur_p = param.clone()
            cur_g = tensor.clone()
            eps = torch.finfo(cur_p.dtype).tiny * 2
            for i in range(1, order):
                # not enough previous values for this difference level yet: store and stop
                if f'prev_g_{i}' not in state:
                    state[f'prev_p_{i}'] = cur_p
                    state[f'prev_g_{i}'] = cur_g
                    break

                s = cur_p - state[f'prev_p_{i}']
                y = cur_g - state[f'prev_g_{i}']
                state[f'prev_p_{i}'] = cur_p
                state[f'prev_g_{i}'] = cur_g
                cur_p = s
                cur_g = y

                # the highest-order difference is what goes into the history
                if i == order - 1:
                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
                    history.append(cur_g.view(-1))

        step = state.get('step', 0)
        if step % update_freq == 0 and len(history) != 0:

            # if maintaining momentum, unproject exp_avg before updating factors and reproject
            exp_avg_proj = state.get("exp_avg_proj", None)
            exp_avg = None
            if exp_avg_proj is not None and "U" in state:
                exp_avg = state["U"] @ exp_avg_proj

            # update factors
            U, L = lm_adagrad_update(
                history,
                damping=setting["damping"],
                rdamping=setting["rdamping"],
                truncate=setting["truncate"],
                tol=setting["tol"],
            )
            # a failed update returns None factors, which leaves the stored ones untouched
            maybe_lerp_(state, U_beta, 'U', U)
            maybe_lerp_(state, L_beta, 'L', L)

            # re-project exp_avg with new factors
            if U is not None and exp_avg_proj is not None:
                assert exp_avg is not None
                state["exp_avg_proj"] = U.T @ exp_avg


        if len(history) != 0:
            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        """Return ``tensor`` preconditioned with the stored factors, or a clipped ``tensor`` when none exist yet."""
        U = state.get('U', None)
        if U is None:
            # make a conservative step to avoid issues due to different GD scaling
            return tensor.clip_(-0.1, 0.1)

        # -------------------------------- transforms -------------------------------- #
        L = state['L']
        if "L" in self.children:
            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
            L = self.inner_step_tensors("L", [L], clone=True)[0]

        if "U" in self.children:
            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
            U = self.inner_step_tensors("U", [U], clone=True)[0]

        # ------------------------------- precondition ------------------------------- #
        g = tensor.view(-1)
        exp_avg_proj = state.get("exp_avg_proj", None)
        update, state["exp_avg_proj"] = lm_adagrad_apply(g, U, L, exp_avg_proj, beta=setting["beta"])
        return update.view_as(tensor)
241
-