torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,191 @@
1
+ tests/test_identical.py,sha256=1_Rv-nwoVL1YcXPDA_DprjbT4jkvr0apUPbeQpebMUI,11508
2
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
3
+ tests/test_module_autograd.py,sha256=cncOlxtxmyJQHUd7nL9aWLRAr1kxtlKgVLqP3_qIb2E,21374
4
+ tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
5
+ tests/test_opts.py,sha256=2KqkvZhisKGmW2wm0d0K1gHjD3IU4V3Z0-6iNVEZ43M,43900
6
+ tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
7
+ tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
8
+ torchzero/__init__.py,sha256=iZFYBgpDDk4d-vunJR1TFqBAHKDAit_SHtepO0RRrIo,131
9
+ torchzero/core/__init__.py,sha256=b6h0FXF9CQgXgaQY6rz_sjjs8zswLh4PZu5FQQqlLUo,339
10
+ torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
11
+ torchzero/core/functional.py,sha256=X3igEP-F2_A6gDza8Gc8l3RcjMuuI4TEwmWA6e6NhKo,3154
12
+ torchzero/core/modular.py,sha256=VGe9PHxOBxnJmMx3_JWFVqq5__xmxqXgQMrfbVJGVa4,9574
13
+ torchzero/core/module.py,sha256=E08ZJnZGNlnojURk8Ztm7w-7Q1Qzq9aEI6NwuvC7OBA,18008
14
+ torchzero/core/objective.py,sha256=6GYQrUDmjufvBJ9JByqnMXFpnjrZY2-mVDnlXRot_HU,39978
15
+ torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
16
+ torchzero/core/transform.py,sha256=4JBpzDPBoKk3QWKm2Z4W6r49rh0UxdsoQjGMLbCmcFs,12245
17
+ torchzero/linalg/__init__.py,sha256=K-9kRBysUd9HEeGfFAu5IbvTvm4xmRpQSu5k_EotYQE,347
18
+ torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
19
+ torchzero/linalg/eigh.py,sha256=Dz3hz10u84mOEYKBp902byzYvn2b6kiDFBCxYzXhIWo,1409
20
+ torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
21
+ torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
22
+ torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
23
+ torchzero/linalg/orthogonalize.py,sha256=RVNoBRu4pXCPAYDV9EFIqsrwePNObKvavIdVTl-1eTU,3513
24
+ torchzero/linalg/qr.py,sha256=v5jNT8BTwQGitdLA8xqOSxmK3fUBHksuvYBeRcscwfo,2541
25
+ torchzero/linalg/solve.py,sha256=Jd7pgEEjxe1h5YQxKOhR8uwCjqCK2f1PY2KEHHUzW84,14496
26
+ torchzero/linalg/svd.py,sha256=lvR8FpLk0H795i6DS4BtBZ1hYpm364JJOrGhOX62pEQ,694
27
+ torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
28
+ torchzero/modules/__init__.py,sha256=k79l7dMEmfxvikxfG8iUinsvtkngErSgKhHV3uqzLyc,587
29
+ torchzero/modules/functional.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
30
+ torchzero/modules/adaptive/__init__.py,sha256=pZVcpL3K3GqpfOHAqcQ6YV0EMib2587NY7FhESU671E,943
31
+ torchzero/modules/adaptive/adagrad.py,sha256=AO7-DqibgfTDXDX11F0sjvfrm668ReMOQCG7HBT8TNY,12031
32
+ torchzero/modules/adaptive/adahessian.py,sha256=2-7n412iK3aN0TvkxFao0nRNLJqlzxrlsfHsNiDkTg8,8730
33
+ torchzero/modules/adaptive/adam.py,sha256=26CB7ja80XKjySIJ8BydnwVAr8qZyV68YFN-CI4el2g,3768
34
+ torchzero/modules/adaptive/adan.py,sha256=XQuVkD1y8a0jmu-A7eHw5d8a6g6Bt0dFOLhRD6pggIg,4108
35
+ torchzero/modules/adaptive/adaptive_heavyball.py,sha256=IguOBiYkBuCPacCFfgFP6oWF9lOS1_931h97PG1N13s,2247
36
+ torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
37
+ torchzero/modules/adaptive/esgd.py,sha256=gQXUCBopPp5KVi2_nwt5NqoyQe6JOyoxcGbKp_OYma0,6169
38
+ torchzero/modules/adaptive/lion.py,sha256=g8JFR_iueSxdSqAmDH4QLgeJ7cRTfarb-Te6zVobzQs,1050
39
+ torchzero/modules/adaptive/lmadagrad.py,sha256=plE1FaB7JN9QqizwGtjzNRNDpxp_qqw1Z6ULUl4yzmA,9086
40
+ torchzero/modules/adaptive/mars.py,sha256=_K2Agi5nC-H-jY0vbBwIMTtvbKUXu341X4YhDpCGQHg,2257
41
+ torchzero/modules/adaptive/matrix_momentum.py,sha256=KBz6fCFVvLH_a99wfxUIvvtBWqm7xq3rToW2sz2vKbU,6790
42
+ torchzero/modules/adaptive/msam.py,sha256=GAOeAQojx4k-cGrB-nN15YQBp_AJknPj4jStgGabcDk,6941
43
+ torchzero/modules/adaptive/muon.py,sha256=CHclVyM7XLEVHjBGPI3a8Wg2S_xcOclr2IOQ4HstZOE,7728
44
+ torchzero/modules/adaptive/natural_gradient.py,sha256=N9pYJcm0i7o-GlQIxs-oHRyhzJxQqtAiI_cIwKovpOc,6582
45
+ torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
46
+ torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
47
+ torchzero/modules/adaptive/rprop.py,sha256=HEZaZXo50HX1W7-X_e5Qhi2wk_d9wOFi1_BrKBvE6mo,11536
48
+ torchzero/modules/adaptive/sam.py,sha256=PdcmxxgtTqGtbkHEA6d6JFy_cIqFVVA2XNxIcXQxvzc,5701
49
+ torchzero/modules/adaptive/shampoo.py,sha256=ld_Nwoh2Hemt5QVB7BdKtaOczdIF1gK_MCiEhqwUnFI,8891
50
+ torchzero/modules/adaptive/soap.py,sha256=cPRkgknFf26HfFss5YQ_RkLBdfy-nOvzFi7agVji3nk,12406
51
+ torchzero/modules/adaptive/sophia_h.py,sha256=3l91pYvJBK1pdNodgYsTTR1cMnkqe6zePWSYSx-CAXQ,7059
52
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
53
+ torchzero/modules/clipping/clipping.py,sha256=_VY78EmzZiYCirZZ99cCSnim5wNgQyeXaXAGxPQUnFk,14184
54
+ torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
55
+ torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
56
+ torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
57
+ torchzero/modules/conjugate_gradient/cg.py,sha256=YQD01wNnSpUdtqcf7egXHpbhFhRtuFD8ZhlNLvc_ZI0,14539
58
+ torchzero/modules/experimental/__init__.py,sha256=nI6WD-EDFfIav8G2mZuAznErweAqZ5C15qFKqY49ci8,728
59
+ torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
60
+ torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
61
+ torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
62
+ torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
63
+ torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
64
+ torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
65
+ torchzero/modules/experimental/l_infinity.py,sha256=nhYusM3YYbc0ptaSf9zlrsqY8EgxlHm9OejJ6VV0qtM,4750
66
+ torchzero/modules/experimental/newton_solver.py,sha256=nsknGw8XDHKyDNE2cFYxRI5qF_mw0vragPtmzTY4-aM,4069
67
+ torchzero/modules/experimental/newtonnewton.py,sha256=ryJV5eDqj4mEvIxERJIPd-0f57M0eFevH4ZAVHhnP4A,3892
68
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
69
+ torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
70
+ torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
71
+ torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
72
+ torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
73
+ torchzero/modules/grad_approximation/fdm.py,sha256=zx70GZDQmhe43bZP5Mbbl31xsMOsGO43kznoQDbqxJo,4372
74
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
75
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
76
+ torchzero/modules/grad_approximation/rfdm.py,sha256=mYS89DI7Rel10QsB9Q2cyrjo4ChhPD1kX00h0t6dc4A,19599
77
+ torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
78
+ torchzero/modules/least_squares/gn.py,sha256=Yl6HgZuZdba83ytzi5B8zBIRE7KfDJiQeO3kxZVjmBM,7146
79
+ torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
80
+ torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
81
+ torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
82
+ torchzero/modules/line_search/backtracking.py,sha256=laQAgy4a6JWovAbjpjGxFn323moU9r3jjv6I5R780sE,9047
83
+ torchzero/modules/line_search/interpolation.py,sha256=tHXlZD1MgfYaymhXY75k9CocltwSYDYZg6ENDCEUiss,4942
84
+ torchzero/modules/line_search/line_search.py,sha256=wqx65OkN3jkyuTuwCQtelVDDWbeWbPYV8Twi2jmrqaw,12926
85
+ torchzero/modules/line_search/scipy.py,sha256=xQ80h9cSyF4Iorq_1NoJglu_Bx4_KeojulBIxvwU6gQ,2836
86
+ torchzero/modules/line_search/strong_wolfe.py,sha256=KRvHnZ0q2nYFfLJry6-Iki936n7B9WwHYDqNgeBsIrs,14975
87
+ torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
88
+ torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
89
+ torchzero/modules/misc/escape.py,sha256=-7TTIZzxiJpVQ_jSFLy64LHijlJfC5y6THvdpNOfrWs,1928
90
+ torchzero/modules/misc/gradient_accumulation.py,sha256=DKJOKewtdvKcSlQEU2ZifB9awdqXU5IxGv68L8yumb0,2269
91
+ torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
92
+ torchzero/modules/misc/misc.py,sha256=lnz071OhYCrpQ-QO8yU77ba3NLZEgvJwJ9dFUDhxupk,15000
93
+ torchzero/modules/misc/multistep.py,sha256=Z0JL1PTRIwnLkBD0tAnxXjCsYJMAr3g3R79lz5DsHNQ,6779
94
+ torchzero/modules/misc/regularization.py,sha256=P5sPHItQcuuOvLioA5x5Mk2ms33E2brjfMaWG0Yzlok,5931
95
+ torchzero/modules/misc/split.py,sha256=PQvI5VPdQo1_Yr9ypn57ZBtkI3EiHovXvIjduoL-LBg,4286
96
+ torchzero/modules/misc/switch.py,sha256=dboHoc1Ijj9ziQEx2HzbT4mnGQH6OVs673BfcfPzLls,3674
97
+ torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
98
+ torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
99
+ torchzero/modules/momentum/cautious.py,sha256=3xlWuuMAWrJGnvtbdNYZNB0xjDeo-gUaOfyCqxCdMGI,8373
100
+ torchzero/modules/momentum/momentum.py,sha256=gxyOudXyABUBU4xodBvWm9Cr6SQgcwf5WEevQWhKjPA,4390
101
+ torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
102
+ torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
103
+ torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
104
+ torchzero/modules/ops/higher_level.py,sha256=F_fv12U1AioRSq0I1FNa9pXYWsyOj-albyTgU9XniiM,8966
105
+ torchzero/modules/ops/multi.py,sha256=9TW9sDlhkewAIFk_YtIb_L9oppDCAh5LAvv3t7w77lM,8641
106
+ torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
107
+ torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
108
+ torchzero/modules/ops/utility.py,sha256=UkR3BfN_NBsZe78oERnknf6lmgeGNiEbtfUDNgU0YoQ,4423
109
+ torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
110
+ torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
111
+ torchzero/modules/projections/galore.py,sha256=70k30-2RJ3ncTNZpjxBhyYq1yJVFGfS2YUvU2trIK4o,258
112
+ torchzero/modules/projections/projection.py,sha256=MMQOplkELu-3fG0qLyx0-uLcDQKCzp7EiCuqTEdPHZ0,14277
113
+ torchzero/modules/quasi_newton/__init__.py,sha256=scn1qRpS1dtUv0u4tQPbjPy9Db66KrxyjS8YNDiKDYQ,543
114
+ torchzero/modules/quasi_newton/damping.py,sha256=EQzxyRcq-KSQP9RJA8UscDo0LNigNXmdFBhyeiKN2p0,2807
115
+ torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=Zx-tlFRa89GhoSP7RFJdLQJPiqPCL7rWaV7WJoQ1YCs,6930
116
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=tBmWXQ1AJE-yikRfx2g2xXRcTfIMzBEDCP_FPPpkMFM,11200
117
+ torchzero/modules/quasi_newton/lsr1.py,sha256=EozLObU28uTg88ZwV4SqBeFrQbJXqK9TfZnKqYON7rk,8512
118
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=B0712bqhrVNhtGiQAN6b4AgR__QdsVirLkMdOx1YGw4,45475
119
+ torchzero/modules/quasi_newton/sg2.py,sha256=b9Ly-BN4q82oki6wRKvDitBUfoTr6u3aMIQhpBak62E,7836
120
+ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6UpyRp5ydgdSg,134
121
+ torchzero/modules/restarts/restars.py,sha256=R-8eVQ1SqUhhZrr9CM2NcIlqs0HsFRaMwjXr6FyjD5w,9262
122
+ torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
123
+ torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
124
+ torchzero/modules/second_order/inm.py,sha256=hafMvqQBwwhTpWElwWouOj_BuMR5v7NacZg-6dpPNbA,3226
125
+ torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
126
+ torchzero/modules/second_order/newton.py,sha256=kGAsZCYd9uMDLUg_dhrS58JsSVjnZ2eqnJmOr-fghRo,11588
127
+ torchzero/modules/second_order/newton_cg.py,sha256=r9gb9hoxKwhQ1FIaZENZscnaPV_9fiJtmuhS7Dd-_7I,14810
128
+ torchzero/modules/second_order/nystrom.py,sha256=GXff6kAge8v-Pt-rCgivEIqXa2xxoEPsByBC9Rkx0kM,11397
129
+ torchzero/modules/second_order/rsn.py,sha256=uUvMXFh6tzciz92myBnsxF1rCGAY71oBPFC04mBGmV8,10369
130
+ torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
131
+ torchzero/modules/smoothing/laplacian.py,sha256=Cs351WVfapQjK4loYlAT6Sx0xIHnc89dhOVNnxUJiGI,5199
132
+ torchzero/modules/smoothing/sampling.py,sha256=tYWK0Rw1CRcultKbMK0C5-TUFXbffjR62338iYYtIck,12942
133
+ torchzero/modules/step_size/__init__.py,sha256=jG0qXpIn17oYXL8b34UjiEbkl002hj3FqJk1uQ5bkCg,136
134
+ torchzero/modules/step_size/adaptive.py,sha256=6YMnaHDtUgC3xdEx8dieDAKrEsYv51eP4craB4JpjP0,14600
135
+ torchzero/modules/step_size/lr.py,sha256=O3LkxvmeUiDjmH6-oILh7WCfYcC8b44yiQDfVLL5TeU,5927
136
+ torchzero/modules/termination/__init__.py,sha256=LkXBiOOYD4ce1Lemj0Vx9BCm_KhRTQTMvm-PD4lQwTs,344
137
+ torchzero/modules/termination/termination.py,sha256=lJXLmtA84JoK_QHhjBxfW9lkxGIxqg7cE7d755MeduE,6905
138
+ torchzero/modules/trust_region/__init__.py,sha256=kWke9FB41-EpjdXCPk8VBwZhpgYalOWSKDI1XWe0yYg,204
139
+ torchzero/modules/trust_region/cubic_regularization.py,sha256=kz4mwo2cD6X6qQKVvjAA1PHbcZy3qLBiDZVKDpY7fxM,6709
140
+ torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
141
+ torchzero/modules/trust_region/levenberg_marquardt.py,sha256=lwjNyrvdOrigjXuToM9qQq4AhpRHQuJ9oYCGyuPcWEg,5035
142
+ torchzero/modules/trust_region/trust_cg.py,sha256=mfTGsAZU0L46k2pdv3w663lXPCkNekbW2eTHQBED1iw,4453
143
+ torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
144
+ torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
145
+ torchzero/modules/variance_reduction/svrg.py,sha256=R3bULX3nw6M7hw4ZN0cpOvxvdy8CLPHYAStPY9qfDng,8804
146
+ torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
147
+ torchzero/modules/weight_decay/reinit.py,sha256=ml70eC4FXh-9G_Iq8tqoJ9R6NBxetROIGp1YaqfGfMw,3326
148
+ torchzero/modules/weight_decay/weight_decay.py,sha256=N9ozvYmEKFOpG84wWfgz7NifYQNwQTBwnXEtR7ygiN4,5352
149
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
150
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=P9Rsilde-SD6WwXQWDrJetZXw7NSowZznvCuFQkps8s,4693
151
+ torchzero/modules/zeroth_order/__init__.py,sha256=1ADUiOHVHzvIP4TpH7_ILmeW2heidfikbf6d5g_1RzY,18
152
+ torchzero/modules/zeroth_order/cd.py,sha256=fTLDtrX4nLX1vG7_isxIUtAkYnxCKS_cCECnPyHTQq0,5017
153
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
154
+ torchzero/optim/root.py,sha256=MnXytsgTjDKhYw3UKux1O-g4vuHsFQu2Emsq5zphu_8,2308
155
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
156
+ torchzero/optim/utility/split.py,sha256=uNnKA40OiXUlr-vlHuU_rLEUfXQXgvr6Cd9yGzjWJiA,1702
157
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
158
+ torchzero/optim/wrappers/directsearch.py,sha256=7f1Zy9ZxftgIFxJ9_1snyMjp8ypB4ULyddOqgxdwkEg,9806
159
+ torchzero/optim/wrappers/fcmaes.py,sha256=Bobx87TWinICd3SyNaKhoplpNbrLdD5vR1J1TZw69Cs,3437
160
+ torchzero/optim/wrappers/mads.py,sha256=Gm1msCsVNJDq-WS913peZ0TzXU_7wgP_sPmMfEK0VIY,2326
161
+ torchzero/optim/wrappers/moors.py,sha256=AjEeovPEvxEGJwbhOaOCrCIxPgwOHilkj-vWb-88Y_0,2073
162
+ torchzero/optim/wrappers/nevergrad.py,sha256=Tmo-9RaFH7OXTJU36UMSALJro2K0AuQk6ow9g9ZoPoU,4890
163
+ torchzero/optim/wrappers/nlopt.py,sha256=FzU_3VuNN0_ngT_FlQs7ZOL_RZcfeFtJNmklqvyzTkw,8694
164
+ torchzero/optim/wrappers/optuna.py,sha256=31K-oeffoKfBc-LmGY5F33g41qlK-RDjZsZO-mZbSRs,2211
165
+ torchzero/optim/wrappers/pybobyqa.py,sha256=1eH3HW2HVyNneZ5vB9ch1hb0OdoYrnk4pJmQVRIOAw4,6309
166
+ torchzero/optim/wrappers/wrapper.py,sha256=eEZnqKRc2cSFVVXNyOGGReMX3MvY0eoHSRjFsQ8pCSw,4380
167
+ torchzero/optim/wrappers/scipy/__init__.py,sha256=lAciW9WTZ8JZ9ZFunyi78cOa6URhTL-rIjk1AdqbrxM,262
168
+ torchzero/optim/wrappers/scipy/basin_hopping.py,sha256=11FuVLYDJxbyxW-44UO-JRef7MErGFfM_PqqqADA8RQ,4435
169
+ torchzero/optim/wrappers/scipy/brute.py,sha256=MD5fWcUo0-COrq5_OKM8X5SSMZaN--4ThW8mdISLX00,1213
170
+ torchzero/optim/wrappers/scipy/differential_evolution.py,sha256=GBJSh0K5ueZX3fY9vxTntQnuIoLDjOQyOj6TMCmTGBg,2669
171
+ torchzero/optim/wrappers/scipy/direct.py,sha256=lTkH8aRY7sWkhrtQ8Z-m8fSdhTpsLxOIQSJprSqThFA,1821
172
+ torchzero/optim/wrappers/scipy/dual_annealing.py,sha256=mIdXy8F3YEfrQCXNxps8T33yRDuqJxaaIjzB0wlKo90,4117
173
+ torchzero/optim/wrappers/scipy/experimental.py,sha256=-_82LtrDVg0qRg6E1coXugsCGGcVNXgxv6m5Anaaw24,4833
174
+ torchzero/optim/wrappers/scipy/minimize.py,sha256=6O2Js2ecGZkNE3mqXJR9Rqp5xlh3KIVYBMF82k5hrws,6491
175
+ torchzero/optim/wrappers/scipy/sgho.py,sha256=mKG9uG-zaEzB6mbiZ25FnXdRcfvGYYOHapuHniyJ8r8,4054
176
+ torchzero/utils/__init__.py,sha256=tpy9ti5Ub5d1zQPFODZ4PjmFKNzoZTd-NByd_snlYtk,761
177
+ torchzero/utils/compile.py,sha256=dY9ioWQvJt71HQN_z4kXX4KgtJ0xOW0MKNlZpHD6118,5130
178
+ torchzero/utils/derivatives.py,sha256=Crma5mVdI4hTPMkoMzEfMTAyCXZIUuzeEzkzQMpBNOY,18089
179
+ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,3130
180
+ torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
181
+ torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
182
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
183
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
184
+ torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
185
+ torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
186
+ torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
187
+ torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
188
+ torchzero-0.4.0.dist-info/METADATA,sha256=WbRYUNngrfEBuPLquDDkrl31aqVDq4mTacBurEfdgmI,564
189
+ torchzero-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
190
+ torchzero-0.4.0.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
191
+ torchzero-0.4.0.dist-info/RECORD,,
tests/test_vars.py DELETED
@@ -1,185 +0,0 @@
1
- import pytest
2
- import torch
3
- from torchzero.core.module import Var
4
- from torchzero.utils.tensorlist import TensorList
5
-
6
- @torch.no_grad
7
- def test_var_get_loss():
8
-
9
- # ---------------------------- test that it works ---------------------------- #
10
- params = [torch.tensor(2.0, requires_grad=True)]
11
- evaluated = False
12
-
13
- def closure_1(backward=True):
14
- assert not backward, 'backward = True'
15
-
16
- # ensure closure only evaluates once
17
- nonlocal evaluated
18
- assert evaluated is False, 'closure was evaluated twice'
19
- evaluated = True
20
-
21
- loss = params[0]**2
22
- if backward:
23
- params[0].grad = None
24
- loss.backward()
25
- else:
26
- assert not loss.requires_grad, "loss requires grad with backward=False"
27
- return loss
28
-
29
- var = Var(params=params, closure=closure_1, model=None, current_step=0)
30
-
31
- assert var.loss is None, var.loss
32
-
33
- assert (loss := var.get_loss(backward=False)) == 4.0, loss
34
- assert evaluated, evaluated
35
- assert loss is var.loss
36
- assert var.loss == 4.0
37
- assert var.loss_approx == 4.0
38
- assert var.grad is None, var.grad
39
-
40
- # reevaluate, which should just return already evaluated loss
41
- assert (loss := var.get_loss(backward=False)) == 4.0, loss
42
- assert var.grad is None, var.grad
43
-
44
-
45
- # ----------------------- test that backward=True works ---------------------- #
46
- params = [torch.tensor(3.0, requires_grad=True)]
47
- evaluated = False
48
-
49
- def closure_2(backward=True):
50
- # ensure closure only evaluates once
51
- nonlocal evaluated
52
- assert evaluated is False, 'closure was evaluated twice'
53
- evaluated = True
54
-
55
- loss = params[0] * 2
56
- if backward:
57
- assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
58
- params[0].grad = None
59
- loss.backward()
60
- else:
61
- assert not loss.requires_grad, "loss requires grad with backward=False"
62
- return loss
63
-
64
- var = Var(params=params, closure=closure_2, model=None, current_step=0)
65
- assert var.grad is None, var.grad
66
- assert (loss := var.get_loss(backward=True)) == 6.0, loss
67
- assert var.grad is not None
68
- assert var.grad[0] == 2.0, var.grad
69
-
70
- # reevaluate, which should just return already evaluated loss
71
- assert (loss := var.get_loss(backward=True)) == 6.0, loss
72
- assert var.grad[0] == 2.0, var.grad
73
-
74
- # get grad, which should just return already evaluated grad
75
- assert (grad := var.get_grad())[0] == 2.0, grad
76
- assert grad is var.grad, grad
77
-
78
- # get update, which should create and return cloned grad
79
- assert var.update is None
80
- assert (update := var.get_update())[0] == 2.0, update
81
- assert update is var.update
82
- assert update is not var.grad
83
- assert var.grad is not None
84
- assert update[0] == var.grad[0]
85
-
86
- @torch.no_grad
87
- def test_var_get_grad():
88
- params = [torch.tensor(2.0, requires_grad=True)]
89
- evaluated = False
90
-
91
- def closure(backward=True):
92
- # ensure closure only evaluates once
93
- nonlocal evaluated
94
- assert evaluated is False, 'closure was evaluated twice'
95
- evaluated = True
96
-
97
- loss = params[0]**2
98
- if backward:
99
- assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
100
- params[0].grad = None
101
- loss.backward()
102
- else:
103
- assert not loss.requires_grad, "loss requires grad with backward=False"
104
- return loss
105
-
106
- var = Var(params=params, closure=closure, model=None, current_step=0)
107
- assert (grad := var.get_grad())[0] == 4.0, grad
108
- assert grad is var.grad
109
-
110
- assert var.loss == 4.0
111
- assert (loss := var.get_loss(backward=False)) == 4.0, loss
112
- assert (loss := var.get_loss(backward=True)) == 4.0, loss
113
- assert var.loss_approx == 4.0
114
-
115
- assert var.update is None, var.update
116
- assert (update := var.get_update())[0] == 4.0, update
117
-
118
- @torch.no_grad
119
- def test_var_get_update():
120
- params = [torch.tensor(2.0, requires_grad=True)]
121
- evaluated = False
122
-
123
- def closure(backward=True):
124
- # ensure closure only evaluates once
125
- nonlocal evaluated
126
- assert evaluated is False, 'closure was evaluated twice'
127
- evaluated = True
128
-
129
- loss = params[0]**2
130
- if backward:
131
- assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
132
- params[0].grad = None
133
- loss.backward()
134
- else:
135
- assert not loss.requires_grad, "loss requires grad with backward=False"
136
- return loss
137
-
138
- var = Var(params=params, closure=closure, model=None, current_step=0)
139
- assert var.update is None, var.update
140
- assert (update := var.get_update())[0] == 4.0, update
141
- assert update is var.update
142
-
143
- assert (grad := var.get_grad())[0] == 4.0, grad
144
- assert grad is var.grad
145
- assert grad is not update
146
-
147
- assert var.loss == 4.0
148
- assert (loss := var.get_loss(backward=False)) == 4.0, loss
149
- assert (loss := var.get_loss(backward=True)) == 4.0, loss
150
- assert var.loss_approx == 4.0
151
-
152
- assert (update := var.get_update())[0] == 4.0, update
153
-
154
-
155
- def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
156
- for k,v in v1.__dict__.items():
157
- if not k.startswith('__'):
158
- # if k == 'post_step_hooks': continue
159
- if k == 'storage': continue
160
- if k == 'update' and clone_update:
161
- if v1.update is None or v2.update is None:
162
- assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
163
- else:
164
- assert (TensorList(v1.update) == TensorList(v2.update)).global_all()
165
- assert v1.update is not v2.update
166
- else:
167
- assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
168
-
169
- def test_var_clone():
170
- model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
171
- def closure(backward): return 1
172
- var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
173
-
174
- _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
175
- _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
176
-
177
- var.grad = TensorList(torch.randn(5))
178
- _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
179
- _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
180
-
181
- var.update = TensorList(torch.randn(5) * 2)
182
- var.loss = torch.randn(1)
183
- var.loss_approx = var.loss
184
- _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
185
- _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
@@ -1,160 +0,0 @@
1
- from collections.abc import Sequence
2
- from functools import partial
3
- from operator import itemgetter
4
- from typing import Literal
5
-
6
- import torch
7
-
8
- from ...core import Target, Transform
9
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
10
- from ..functional import ema_, ema_sq_, sqrt_ema_sq_
11
- from ..momentum.momentum import nag_
12
- from ..ops.higher_level import EMASquared, SqrtEMASquared
13
-
14
-
15
- def precentered_ema_sq_(
16
- tensors: TensorList,
17
- exp_avg_: TensorList,
18
- exp_avg_sq_: TensorList,
19
- beta1: float | NumberList,
20
- beta2: float | NumberList,
21
- step: int,
22
- min_step: int,
23
- pow: float,
24
- max_exp_avg_sq_: TensorList | None,
25
- ):
26
- """
27
- Squared EMA of (update - 1st EMA). Starts taking effect after `min_step` to avoid division by epsilon.
28
-
29
- returns `exp_avg_sq_` or `max_exp_avg_sq_`.
30
- """
31
- exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=False)
32
-
33
- if step < min_step: centered_update = tensors
34
- else: centered_update = tensors - exp_avg_
35
-
36
- exp_avg_sq_=ema_sq_(
37
- centered_update,
38
- exp_avg_sq_=exp_avg_sq_,
39
- beta=beta2,
40
- pow=pow,
41
- max_exp_avg_sq_=max_exp_avg_sq_,
42
- )
43
- return exp_avg_sq_
44
-
45
- class PrecenteredEMASquared(Transform):
46
- """Maintains un-squared EMA, the updates are centered by it before being fed into squared EMA."""
47
- def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
48
- defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
49
- super().__init__(defaults, uses_grad=False, target=target)
50
-
51
- @torch.no_grad
52
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
53
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
54
-
55
- beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
56
- amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])
57
-
58
- if amsgrad:
59
- exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
60
- else:
61
- exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
62
- max_exp_avg_sq = None
63
-
64
- return precentered_ema_sq_(
65
- TensorList(tensors),
66
- exp_avg_ = exp_avg,
67
- exp_avg_sq_=exp_avg_sq,
68
- beta1=beta1,
69
- beta2=beta2,
70
- step = step,
71
- min_step=min_step,
72
- pow=pow,
73
- max_exp_avg_sq_=max_exp_avg_sq,
74
- ).clone()
75
-
76
-
77
- def nag_ema_sq_(
78
- tensors: TensorList,
79
- exp_avg_sq_: TensorList,
80
- beta: float | NumberList,
81
- max_exp_avg_sq_: TensorList | None,
82
- pow: float,
83
- lerp:bool=True,
84
- ):
85
- """
86
- Nesterov EMA of squared tensors.
87
-
88
- Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
89
- """
90
- if pow == 1: tensors = tensors.abs()
91
- elif pow%2 == 0: tensors = tensors.pow(pow)
92
- else: tensors = tensors.pow(pow).abs()
93
-
94
- exp_avg_sq_=nag_(tensors,velocity_=exp_avg_sq_,momentum=beta,dampening=0,lerp=lerp,)
95
-
96
- # AMSGrad
97
- if max_exp_avg_sq_ is not None:
98
- max_exp_avg_sq_.maximum_(exp_avg_sq_)
99
- exp_avg_sq_ = max_exp_avg_sq_
100
-
101
- return exp_avg_sq_
102
-
103
- def sqrt_nag_ema_sq_(
104
- tensors: TensorList,
105
- exp_avg_sq_: TensorList,
106
- beta: float | NumberList,
107
- max_exp_avg_sq_: TensorList | None,
108
- debiased: bool,
109
- step: int,
110
- pow: float,
111
- lerp:bool=False,
112
- ):
113
- """
114
- Square root of nesterov EMA of squared tensors.
115
-
116
- Returns new tensors.
117
- """
118
- return sqrt_ema_sq_(tensors=tensors,exp_avg_sq_=exp_avg_sq_,beta=beta,max_exp_avg_sq_=max_exp_avg_sq_,
119
- pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
120
-
121
- class NesterovEMASquared(EMASquared):
122
- """squared momentum with nesterov momentum rule"""
123
- EMA_SQ_FN = staticmethod(nag_ema_sq_)
124
-
125
- class SqrtNesterovEMASquared(SqrtEMASquared):
126
- """square root of squared momentum with nesterov momentum rule"""
127
- SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
128
-
129
-
130
- def coordinate_momentum_(
131
- tensors: TensorList,
132
- velocity_: TensorList,
133
- p: float | NumberList,
134
- ):
135
- """
136
- sets `velocity_` to p% random values from `tensors`.
137
-
138
- Returns `velocity_`
139
- """
140
- mask = tensors.bernoulli_like(p).as_bool()
141
- velocity_.masked_set_(mask, tensors)
142
- return velocity_
143
-
144
-
145
- class CoordinateMomentum(Transform):
146
- """Maintains a momentum buffer, on each step each value in the buffer has :code:`p` chance to be updated with the new value.
147
-
148
- Args:
149
- p (float, optional): _description_. Defaults to 0.1.
150
- target (Target, optional): _description_. Defaults to 'update'.
151
- """
152
- def __init__(self, p: float = 0.1, target: Target = 'update'):
153
- defaults = dict(p=p)
154
- super().__init__(defaults, uses_grad=False, target=target)
155
-
156
- @torch.no_grad
157
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
158
- p = NumberList(s['p'] for s in settings)
159
- velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
160
- return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
@@ -1 +0,0 @@
1
- from .higher_order_newton import HigherOrderNewton