torchzero 0.3.10__tar.gz → 0.3.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. torchzero-0.3.13/PKG-INFO +14 -0
  2. {torchzero-0.3.10 → torchzero-0.3.13}/pyproject.toml +2 -2
  3. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_identical.py +2 -3
  4. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_opts.py +140 -100
  5. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_tensorlist.py +8 -7
  6. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_vars.py +1 -0
  7. torchzero-0.3.13/torchzero/__init__.py +4 -0
  8. torchzero-0.3.13/torchzero/core/__init__.py +2 -0
  9. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/core/module.py +335 -50
  10. torchzero-0.3.13/torchzero/core/reformulation.py +65 -0
  11. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/core/transform.py +197 -70
  12. torchzero-0.3.13/torchzero/modules/__init__.py +23 -0
  13. torchzero-0.3.13/torchzero/modules/adaptive/__init__.py +30 -0
  14. torchzero-0.3.13/torchzero/modules/adaptive/adagrad.py +356 -0
  15. torchzero-0.3.13/torchzero/modules/adaptive/adahessian.py +224 -0
  16. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/adam.py +6 -8
  17. torchzero-0.3.13/torchzero/modules/adaptive/adan.py +96 -0
  18. torchzero-0.3.13/torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  19. torchzero-0.3.13/torchzero/modules/adaptive/aegd.py +54 -0
  20. torchzero-0.3.13/torchzero/modules/adaptive/esgd.py +171 -0
  21. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/lion.py +1 -1
  22. torchzero-0.3.10/torchzero/modules/experimental/spectral.py → torchzero-0.3.13/torchzero/modules/adaptive/lmadagrad.py +94 -71
  23. torchzero-0.3.13/torchzero/modules/adaptive/mars.py +79 -0
  24. torchzero-0.3.13/torchzero/modules/adaptive/matrix_momentum.py +146 -0
  25. torchzero-0.3.13/torchzero/modules/adaptive/msam.py +188 -0
  26. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/muon.py +29 -5
  27. torchzero-0.3.13/torchzero/modules/adaptive/natural_gradient.py +175 -0
  28. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/orthograd.py +1 -1
  29. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/rmsprop.py +7 -4
  30. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/rprop.py +42 -10
  31. torchzero-0.3.13/torchzero/modules/adaptive/sam.py +163 -0
  32. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/shampoo.py +47 -9
  33. {torchzero-0.3.10/torchzero/modules/optimizers → torchzero-0.3.13/torchzero/modules/adaptive}/soap.py +52 -65
  34. torchzero-0.3.13/torchzero/modules/adaptive/sophia_h.py +185 -0
  35. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/clipping/clipping.py +115 -25
  36. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/clipping/ema_clipping.py +31 -17
  37. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/clipping/growth_clipping.py +8 -7
  38. torchzero-0.3.13/torchzero/modules/conjugate_gradient/__init__.py +11 -0
  39. torchzero-0.3.13/torchzero/modules/conjugate_gradient/cg.py +355 -0
  40. torchzero-0.3.13/torchzero/modules/experimental/__init__.py +18 -0
  41. {torchzero-0.3.10/torchzero/modules/projections → torchzero-0.3.13/torchzero/modules/experimental}/dct.py +11 -11
  42. {torchzero-0.3.10/torchzero/modules/projections → torchzero-0.3.13/torchzero/modules/experimental}/fft.py +10 -10
  43. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/experimental/gradmin.py +4 -3
  44. torchzero-0.3.13/torchzero/modules/experimental/l_infinity.py +111 -0
  45. torchzero-0.3.10/torchzero/modules/momentum/experimental.py → torchzero-0.3.13/torchzero/modules/experimental/momentum.py +5 -42
  46. torchzero-0.3.13/torchzero/modules/experimental/newton_solver.py +150 -0
  47. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/experimental/newtonnewton.py +32 -15
  48. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  49. torchzero-0.3.13/torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  50. torchzero-0.3.10/torchzero/modules/projections/structural.py → torchzero-0.3.13/torchzero/modules/experimental/structural_projections.py +13 -55
  51. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/functional.py +52 -6
  52. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/grad_approximation/fdm.py +30 -4
  53. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  54. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  55. torchzero-0.3.13/torchzero/modules/grad_approximation/rfdm.py +541 -0
  56. torchzero-0.3.13/torchzero/modules/higher_order/__init__.py +1 -0
  57. torchzero-0.3.13/torchzero/modules/higher_order/higher_order_newton.py +327 -0
  58. torchzero-0.3.13/torchzero/modules/least_squares/__init__.py +1 -0
  59. torchzero-0.3.13/torchzero/modules/least_squares/gn.py +161 -0
  60. torchzero-0.3.13/torchzero/modules/line_search/__init__.py +5 -0
  61. torchzero-0.3.13/torchzero/modules/line_search/_polyinterp.py +289 -0
  62. torchzero-0.3.13/torchzero/modules/line_search/adaptive.py +124 -0
  63. torchzero-0.3.13/torchzero/modules/line_search/backtracking.py +243 -0
  64. torchzero-0.3.13/torchzero/modules/line_search/line_search.py +330 -0
  65. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/line_search/scipy.py +3 -3
  66. torchzero-0.3.13/torchzero/modules/line_search/strong_wolfe.py +377 -0
  67. torchzero-0.3.13/torchzero/modules/misc/__init__.py +35 -0
  68. torchzero-0.3.13/torchzero/modules/misc/debug.py +48 -0
  69. torchzero-0.3.13/torchzero/modules/misc/escape.py +62 -0
  70. torchzero-0.3.13/torchzero/modules/misc/gradient_accumulation.py +136 -0
  71. torchzero-0.3.13/torchzero/modules/misc/homotopy.py +59 -0
  72. torchzero-0.3.13/torchzero/modules/misc/misc.py +383 -0
  73. torchzero-0.3.13/torchzero/modules/misc/multistep.py +194 -0
  74. torchzero-0.3.13/torchzero/modules/misc/regularization.py +167 -0
  75. torchzero-0.3.13/torchzero/modules/misc/split.py +123 -0
  76. {torchzero-0.3.10/torchzero/modules/ops → torchzero-0.3.13/torchzero/modules/misc}/switch.py +45 -4
  77. torchzero-0.3.13/torchzero/modules/momentum/__init__.py +10 -0
  78. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/momentum/averaging.py +9 -9
  79. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/momentum/cautious.py +51 -19
  80. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/momentum/momentum.py +37 -2
  81. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/__init__.py +11 -31
  82. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/accumulate.py +6 -10
  83. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/binary.py +81 -34
  84. torchzero-0.3.10/torchzero/modules/momentum/ema.py → torchzero-0.3.13/torchzero/modules/ops/higher_level.py +16 -39
  85. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/multi.py +82 -21
  86. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/reduce.py +16 -8
  87. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/unary.py +29 -13
  88. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/ops/utility.py +30 -18
  89. torchzero-0.3.13/torchzero/modules/projections/__init__.py +3 -0
  90. torchzero-0.3.13/torchzero/modules/projections/cast.py +51 -0
  91. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/projections/galore.py +3 -1
  92. torchzero-0.3.13/torchzero/modules/projections/projection.py +338 -0
  93. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/quasi_newton/__init__.py +9 -14
  94. torchzero-0.3.13/torchzero/modules/quasi_newton/damping.py +105 -0
  95. torchzero-0.3.13/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  96. torchzero-0.3.13/torchzero/modules/quasi_newton/lbfgs.py +342 -0
  97. torchzero-0.3.13/torchzero/modules/quasi_newton/lsr1.py +253 -0
  98. torchzero-0.3.13/torchzero/modules/quasi_newton/quasi_newton.py +1231 -0
  99. torchzero-0.3.13/torchzero/modules/restarts/__init__.py +7 -0
  100. torchzero-0.3.13/torchzero/modules/restarts/restars.py +252 -0
  101. torchzero-0.3.13/torchzero/modules/second_order/__init__.py +4 -0
  102. torchzero-0.3.13/torchzero/modules/second_order/multipoint.py +238 -0
  103. torchzero-0.3.13/torchzero/modules/second_order/newton.py +383 -0
  104. torchzero-0.3.13/torchzero/modules/second_order/newton_cg.py +435 -0
  105. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/second_order/nystrom.py +104 -1
  106. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/smoothing/__init__.py +1 -1
  107. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/smoothing/laplacian.py +14 -4
  108. torchzero-0.3.13/torchzero/modules/smoothing/sampling.py +300 -0
  109. torchzero-0.3.13/torchzero/modules/step_size/__init__.py +2 -0
  110. torchzero-0.3.13/torchzero/modules/step_size/adaptive.py +387 -0
  111. torchzero-0.3.13/torchzero/modules/step_size/lr.py +154 -0
  112. torchzero-0.3.13/torchzero/modules/termination/__init__.py +14 -0
  113. torchzero-0.3.13/torchzero/modules/termination/termination.py +207 -0
  114. torchzero-0.3.13/torchzero/modules/trust_region/__init__.py +5 -0
  115. torchzero-0.3.13/torchzero/modules/trust_region/cubic_regularization.py +170 -0
  116. torchzero-0.3.13/torchzero/modules/trust_region/dogleg.py +92 -0
  117. torchzero-0.3.13/torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  118. torchzero-0.3.13/torchzero/modules/trust_region/trust_cg.py +97 -0
  119. torchzero-0.3.13/torchzero/modules/trust_region/trust_region.py +350 -0
  120. torchzero-0.3.13/torchzero/modules/variance_reduction/__init__.py +1 -0
  121. torchzero-0.3.13/torchzero/modules/variance_reduction/svrg.py +208 -0
  122. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/weight_decay/__init__.py +1 -1
  123. torchzero-0.3.13/torchzero/modules/weight_decay/weight_decay.py +169 -0
  124. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/wrappers/optim_wrapper.py +29 -1
  125. torchzero-0.3.13/torchzero/modules/zeroth_order/__init__.py +1 -0
  126. torchzero-0.3.13/torchzero/modules/zeroth_order/cd.py +359 -0
  127. torchzero-0.3.13/torchzero/optim/root.py +65 -0
  128. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/utility/split.py +8 -8
  129. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/directsearch.py +39 -3
  130. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/fcmaes.py +24 -15
  131. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/mads.py +5 -6
  132. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/nevergrad.py +16 -1
  133. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/nlopt.py +0 -2
  134. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/optuna.py +3 -3
  135. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/scipy.py +86 -25
  136. torchzero-0.3.13/torchzero/utils/__init__.py +59 -0
  137. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/compile.py +1 -1
  138. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/derivatives.py +126 -114
  139. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/__init__.py +9 -2
  140. torchzero-0.3.13/torchzero/utils/linalg/linear_operator.py +329 -0
  141. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/matrix_funcs.py +2 -2
  142. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/orthogonalize.py +2 -1
  143. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/qr.py +2 -2
  144. torchzero-0.3.13/torchzero/utils/linalg/solve.py +480 -0
  145. torchzero-0.3.13/torchzero/utils/metrics.py +83 -0
  146. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/numberlist.py +2 -0
  147. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/python_tools.py +16 -0
  148. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/tensorlist.py +134 -51
  149. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/torch_tools.py +9 -4
  150. torchzero-0.3.13/torchzero.egg-info/PKG-INFO +14 -0
  151. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero.egg-info/SOURCES.txt +71 -45
  152. torchzero-0.3.10/LICENSE +0 -21
  153. torchzero-0.3.10/PKG-INFO +0 -379
  154. torchzero-0.3.10/README.md +0 -340
  155. torchzero-0.3.10/docs/source/conf.py +0 -57
  156. torchzero-0.3.10/torchzero/__init__.py +0 -4
  157. torchzero-0.3.10/torchzero/core/__init__.py +0 -2
  158. torchzero-0.3.10/torchzero/modules/__init__.py +0 -14
  159. torchzero-0.3.10/torchzero/modules/experimental/__init__.py +0 -24
  160. torchzero-0.3.10/torchzero/modules/experimental/absoap.py +0 -250
  161. torchzero-0.3.10/torchzero/modules/experimental/adadam.py +0 -112
  162. torchzero-0.3.10/torchzero/modules/experimental/adamY.py +0 -125
  163. torchzero-0.3.10/torchzero/modules/experimental/adasoap.py +0 -172
  164. torchzero-0.3.10/torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  165. torchzero-0.3.10/torchzero/modules/experimental/eigendescent.py +0 -117
  166. torchzero-0.3.10/torchzero/modules/experimental/etf.py +0 -172
  167. torchzero-0.3.10/torchzero/modules/experimental/newton_solver.py +0 -88
  168. torchzero-0.3.10/torchzero/modules/experimental/soapy.py +0 -163
  169. torchzero-0.3.10/torchzero/modules/experimental/structured_newton.py +0 -111
  170. torchzero-0.3.10/torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  171. torchzero-0.3.10/torchzero/modules/experimental/tada.py +0 -38
  172. torchzero-0.3.10/torchzero/modules/grad_approximation/rfdm.py +0 -272
  173. torchzero-0.3.10/torchzero/modules/higher_order/__init__.py +0 -1
  174. torchzero-0.3.10/torchzero/modules/higher_order/higher_order_newton.py +0 -256
  175. torchzero-0.3.10/torchzero/modules/line_search/__init__.py +0 -5
  176. torchzero-0.3.10/torchzero/modules/line_search/backtracking.py +0 -205
  177. torchzero-0.3.10/torchzero/modules/line_search/line_search.py +0 -181
  178. torchzero-0.3.10/torchzero/modules/line_search/strong_wolfe.py +0 -249
  179. torchzero-0.3.10/torchzero/modules/line_search/trust_region.py +0 -73
  180. torchzero-0.3.10/torchzero/modules/lr/__init__.py +0 -2
  181. torchzero-0.3.10/torchzero/modules/lr/adaptive.py +0 -93
  182. torchzero-0.3.10/torchzero/modules/lr/lr.py +0 -63
  183. torchzero-0.3.10/torchzero/modules/momentum/__init__.py +0 -14
  184. torchzero-0.3.10/torchzero/modules/momentum/matrix_momentum.py +0 -166
  185. torchzero-0.3.10/torchzero/modules/ops/debug.py +0 -25
  186. torchzero-0.3.10/torchzero/modules/ops/misc.py +0 -418
  187. torchzero-0.3.10/torchzero/modules/ops/split.py +0 -75
  188. torchzero-0.3.10/torchzero/modules/optimizers/__init__.py +0 -18
  189. torchzero-0.3.10/torchzero/modules/optimizers/adagrad.py +0 -155
  190. torchzero-0.3.10/torchzero/modules/optimizers/sophia_h.py +0 -129
  191. torchzero-0.3.10/torchzero/modules/projections/__init__.py +0 -5
  192. torchzero-0.3.10/torchzero/modules/projections/projection.py +0 -244
  193. torchzero-0.3.10/torchzero/modules/quasi_newton/cg.py +0 -268
  194. torchzero-0.3.10/torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  195. torchzero-0.3.10/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  196. torchzero-0.3.10/torchzero/modules/quasi_newton/lbfgs.py +0 -229
  197. torchzero-0.3.10/torchzero/modules/quasi_newton/lsr1.py +0 -174
  198. torchzero-0.3.10/torchzero/modules/quasi_newton/olbfgs.py +0 -196
  199. torchzero-0.3.10/torchzero/modules/quasi_newton/quasi_newton.py +0 -683
  200. torchzero-0.3.10/torchzero/modules/second_order/__init__.py +0 -3
  201. torchzero-0.3.10/torchzero/modules/second_order/newton.py +0 -159
  202. torchzero-0.3.10/torchzero/modules/second_order/newton_cg.py +0 -85
  203. torchzero-0.3.10/torchzero/modules/smoothing/gaussian.py +0 -164
  204. torchzero-0.3.10/torchzero/modules/weight_decay/weight_decay.py +0 -86
  205. torchzero-0.3.10/torchzero/utils/__init__.py +0 -23
  206. torchzero-0.3.10/torchzero/utils/linalg/solve.py +0 -169
  207. torchzero-0.3.10/torchzero.egg-info/PKG-INFO +0 -379
  208. {torchzero-0.3.10 → torchzero-0.3.13}/setup.cfg +0 -0
  209. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_module.py +0 -0
  210. {torchzero-0.3.10 → torchzero-0.3.13}/tests/test_utils_optimizer.py +0 -0
  211. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/clipping/__init__.py +0 -0
  212. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/experimental/curveball.py +0 -0
  213. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/grad_approximation/__init__.py +0 -0
  214. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/modules/wrappers/__init__.py +0 -0
  215. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/__init__.py +0 -0
  216. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/utility/__init__.py +0 -0
  217. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/optim/wrappers/__init__.py +0 -0
  218. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/benchmark.py +0 -0
  219. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/linalg/svd.py +0 -0
  220. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/ops.py +0 -0
  221. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/optimizer.py +0 -0
  222. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/optuna_tools.py +0 -0
  223. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero/utils/params.py +0 -0
  224. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero.egg-info/dependency_links.txt +0 -0
  225. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero.egg-info/requires.txt +0 -0
  226. {torchzero-0.3.10 → torchzero-0.3.13}/torchzero.egg-info/top_level.txt +0 -0
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchzero
3
+ Version: 0.3.13
4
+ Summary: Modular optimization library for PyTorch.
5
+ Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/inikishev/torchzero
7
+ Project-URL: Repository, https://github.com/inikishev/torchzero
8
+ Project-URL: Issues, https://github.com/inikishev/torchzero/isses
9
+ Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: torch
13
+ Requires-Dist: numpy
14
+ Requires-Dist: typing_extensions
@@ -1,5 +1,5 @@
1
1
  # NEW VERSION TUTORIAL FOR MYSELF
2
- # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
2
+ # STEP 1 - COMMIT NEW CHANGES AND PUSH THEM
3
3
  # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
4
4
  # STEP 3 - CREATE TAG WITH THAT VERSION
5
5
  # STEP 4 - PUSH (SYNC) CHANGES
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
13
13
  name = "torchzero"
14
14
  description = "Modular optimization library for PyTorch."
15
15
 
16
- version = "0.3.10"
16
+ version = "0.3.13"
17
17
  dependencies = [
18
18
  "torch",
19
19
  "numpy",
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
96
96
 
97
97
  @pytest.mark.parametrize('amsgrad', [True, False])
98
98
  def test_adam(amsgrad):
99
- # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
100
- # pytorch applies debiasing separately so it is applied before epsilo
99
+ torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
101
100
  tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
102
101
  tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
103
102
  tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
133
132
  tz.m.Debias2(beta=0.999),
134
133
  tz.m.Add(1e-8)]
135
134
  ))
136
- tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
135
+ tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
137
136
 
138
137
  _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
139
138
  for fn in tz_fns:
@@ -56,14 +56,17 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
56
56
  if use_closure:
57
57
  def closure(backward=True):
58
58
  loss = objective()
59
+ losses.append(loss.detach())
59
60
  if backward:
60
61
  opt.zero_grad()
61
62
  loss.backward()
62
63
  return loss
63
- loss = opt.step(closure)
64
- assert loss is not None
65
- assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
66
- losses.append(loss)
64
+ ret = opt.step(closure)
65
+ assert ret is not None # the return should be the loss
66
+ with torch.no_grad():
67
+ loss = objective() # in case f(x_0) is not evaluated
68
+ assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
69
+ losses.append(loss.detach())
67
70
 
68
71
  else:
69
72
  loss = objective()
@@ -71,7 +74,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
71
74
  loss.backward()
72
75
  opt.step()
73
76
  assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
74
- losses.append(loss)
77
+ losses.append(loss.detach())
75
78
 
76
79
  losses.append(objective())
77
80
  return torch.stack(losses).nan_to_num(0,10000,10000).min()
@@ -292,42 +295,42 @@ FDM_central2 = Run(
292
295
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
293
296
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
294
297
  needs_closure=True,
295
- func='booth', steps=50, loss=1e-7, merge_invariant=True,
298
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
296
299
  sphere_steps=2, sphere_loss=340,
297
300
  )
298
301
  FDM_forward2 = Run(
299
302
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
300
303
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
301
304
  needs_closure=True,
302
- func='booth', steps=50, loss=1e-7, merge_invariant=True,
305
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
303
306
  sphere_steps=2, sphere_loss=340,
304
307
  )
305
308
  FDM_backward2 = Run(
306
309
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
307
310
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
308
311
  needs_closure=True,
309
- func='booth', steps=50, loss=2e-7, merge_invariant=True,
312
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
310
313
  sphere_steps=2, sphere_loss=340,
311
314
  )
312
315
  FDM_forward3 = Run(
313
316
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
314
317
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
315
318
  needs_closure=True,
316
- func='booth', steps=50, loss=3e-7, merge_invariant=True,
319
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
317
320
  sphere_steps=2, sphere_loss=340,
318
321
  )
319
322
  FDM_backward3 = Run(
320
323
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
321
324
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
322
325
  needs_closure=True,
323
- func='booth', steps=50, loss=3e-7, merge_invariant=True,
326
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
324
327
  sphere_steps=2, sphere_loss=340,
325
328
  )
326
329
  FDM_central4 = Run(
327
330
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
328
331
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
329
332
  needs_closure=True,
330
- func='booth', steps=50, loss=2e-8, merge_invariant=True,
333
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
331
334
  sphere_steps=2, sphere_loss=340,
332
335
  )
333
336
 
@@ -374,6 +377,21 @@ RandomizedFDM_central4 = Run(
374
377
  func='booth', steps=50, loss=10, merge_invariant=True,
375
378
  sphere_steps=100, sphere_loss=450,
376
379
  )
380
+ RandomizedFDM_forward4 = Run(
381
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
382
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
383
+ needs_closure=True,
384
+ func='booth', steps=50, loss=10, merge_invariant=True,
385
+ sphere_steps=100, sphere_loss=450,
386
+ )
387
+ RandomizedFDM_forward5 = Run(
388
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
389
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
390
+ needs_closure=True,
391
+ func='booth', steps=50, loss=10, merge_invariant=True,
392
+ sphere_steps=100, sphere_loss=450,
393
+ )
394
+
377
395
 
378
396
  RandomizedFDM_4samples = Run(
379
397
  func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
@@ -455,25 +473,11 @@ Backtracking = Run(
455
473
  func='booth', steps=50, loss=0, merge_invariant=True,
456
474
  sphere_steps=2, sphere_loss=0,
457
475
  )
458
- Backtracking_try_negative = Run(
459
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
460
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
461
- needs_closure=True,
462
- func='booth', steps=50, loss=1e-9, merge_invariant=True,
463
- sphere_steps=2, sphere_loss=1e-10,
464
- )
465
476
  AdaptiveBacktracking = Run(
466
477
  func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
467
478
  sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
468
479
  needs_closure=True,
469
- func='booth', steps=50, loss=0, merge_invariant=True,
470
- sphere_steps=2, sphere_loss=0,
471
- )
472
- AdaptiveBacktracking_try_negative = Run(
473
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
474
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
475
- needs_closure=True,
476
- func='booth', steps=50, loss=1e-8, merge_invariant=True,
480
+ func='booth', steps=50, loss=1e-11, merge_invariant=True,
477
481
  sphere_steps=2, sphere_loss=1e-10,
478
482
  )
479
483
  # ----------------------------- line_search/scipy ---------------------------- #
@@ -494,15 +498,6 @@ StrongWolfe = Run(
494
498
  sphere_steps=2, sphere_loss=0,
495
499
  )
496
500
 
497
- # ------------------------- line_search/trust_region ------------------------- #
498
- TrustRegion = Run(
499
- func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
500
- sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
501
- needs_closure=True,
502
- func='booth', steps=50, loss=0.1, merge_invariant=True,
503
- sphere_steps=10, sphere_loss=1e-5,
504
- )
505
-
506
501
  # ----------------------------------- lr/lr ---------------------------------- #
507
502
  LR = Run(
508
503
  func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
@@ -587,8 +582,8 @@ UpdateGradientSignConsistency = Run(
587
582
  sphere_steps=10, sphere_loss=2,
588
583
  )
589
584
  IntermoduleCautious = Run(
590
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
591
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
585
+ func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
586
+ sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
592
587
  needs_closure=False,
593
588
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
594
589
  sphere_steps=10, sphere_loss=0.1,
@@ -601,8 +596,8 @@ ScaleByGradCosineSimilarity = Run(
601
596
  sphere_steps=10, sphere_loss=0.1,
602
597
  )
603
598
  ScaleModulesByCosineSimilarity = Run(
604
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
605
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
599
+ func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
600
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
606
601
  needs_closure=False,
607
602
  func='booth', steps=50, loss=0.005, merge_invariant=True,
608
603
  sphere_steps=10, sphere_loss=0.1,
@@ -610,47 +605,69 @@ ScaleModulesByCosineSimilarity = Run(
610
605
 
611
606
  # ------------------------- momentum/matrix_momentum ------------------------- #
612
607
  MatrixMomentum_forward = Run(
613
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
614
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
608
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
609
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
615
610
  needs_closure=True,
616
611
  func='booth', steps=50, loss=0.05, merge_invariant=True,
617
- sphere_steps=10, sphere_loss=0,
612
+ sphere_steps=10, sphere_loss=0.01,
618
613
  )
619
614
  MatrixMomentum_forward = Run(
620
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
621
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
615
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
616
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
622
617
  needs_closure=True,
623
618
  func='booth', steps=50, loss=0.05, merge_invariant=True,
624
- sphere_steps=10, sphere_loss=0,
619
+ sphere_steps=10, sphere_loss=0.01,
625
620
  )
626
621
  MatrixMomentum_forward = Run(
627
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
628
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
622
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
623
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
629
624
  needs_closure=True,
630
625
  func='booth', steps=50, loss=0.05, merge_invariant=True,
631
- sphere_steps=10, sphere_loss=0,
626
+ sphere_steps=10, sphere_loss=0.01,
632
627
  )
633
628
 
634
629
  AdaptiveMatrixMomentum_forward = Run(
635
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
636
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
630
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
631
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
637
632
  needs_closure=True,
638
- func='booth', steps=50, loss=0.002, merge_invariant=True,
639
- sphere_steps=10, sphere_loss=0,
633
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
634
+ sphere_steps=10, sphere_loss=0.05,
640
635
  )
641
636
  AdaptiveMatrixMomentum_central = Run(
642
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
643
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
637
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
638
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
644
639
  needs_closure=True,
645
- func='booth', steps=50, loss=0.002, merge_invariant=True,
646
- sphere_steps=10, sphere_loss=0,
640
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
641
+ sphere_steps=10, sphere_loss=0.05,
647
642
  )
648
643
  AdaptiveMatrixMomentum_autograd = Run(
649
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
650
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
644
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
645
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
651
646
  needs_closure=True,
652
- func='booth', steps=50, loss=0.002, merge_invariant=True,
653
- sphere_steps=10, sphere_loss=0,
647
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
648
+ sphere_steps=10, sphere_loss=0.05,
649
+ )
650
+
651
+ StochasticAdaptiveMatrixMomentum_forward = Run(
652
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
653
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
654
+ needs_closure=True,
655
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
656
+ sphere_steps=10, sphere_loss=0.05,
657
+ )
658
+ StochasticAdaptiveMatrixMomentum_central = Run(
659
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
660
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
661
+ needs_closure=True,
662
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
663
+ sphere_steps=10, sphere_loss=0.05,
664
+ )
665
+ StochasticAdaptiveMatrixMomentum_autograd = Run(
666
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
667
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
668
+ needs_closure=True,
669
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
670
+ sphere_steps=10, sphere_loss=0.05,
654
671
  )
655
672
 
656
673
  # EMA, momentum are covered by test_identical
@@ -677,15 +694,15 @@ UpdateSign = Run(
677
694
  sphere_steps=10, sphere_loss=0,
678
695
  )
679
696
  GradAccumulation = Run(
680
- func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05), 10), ),
681
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5), 10), ),
697
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
698
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
682
699
  needs_closure=False,
683
700
  func='booth', steps=50, loss=25, merge_invariant=True,
684
701
  sphere_steps=20, sphere_loss=1e-11,
685
702
  )
686
703
  NegateOnLossIncrease = Run(
687
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
688
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
704
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
705
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
689
706
  needs_closure=True,
690
707
  func='booth', steps=50, loss=0.1, merge_invariant=True,
691
708
  sphere_steps=20, sphere_loss=0.001,
@@ -693,7 +710,7 @@ NegateOnLossIncrease = Run(
693
710
  # -------------------------------- misc/switch ------------------------------- #
694
711
  Alternate = Run(
695
712
  func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
696
- sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
713
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
697
714
  needs_closure=False,
698
715
  func='booth', steps=50, loss=1, merge_invariant=True,
699
716
  sphere_steps=20, sphere_loss=20,
@@ -734,24 +751,24 @@ Shampoo = Run(
734
751
 
735
752
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
736
753
  BFGS = Run(
737
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
738
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
754
+ func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
755
+ sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
739
756
  needs_closure=True,
740
- func='rosen', steps=50, loss=0, merge_invariant=True,
741
- sphere_steps=10, sphere_loss=0,
757
+ func='rosen', steps=50, loss=1e-10, merge_invariant=True,
758
+ sphere_steps=10, sphere_loss=1e-10,
742
759
  )
743
760
  SR1 = Run(
744
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
745
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
761
+ func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
762
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
746
763
  needs_closure=True,
747
764
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
748
765
  sphere_steps=10, sphere_loss=0,
749
766
  )
750
767
  SSVM = Run(
751
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
752
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
768
+ func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
769
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
753
770
  needs_closure=True,
754
- func='rosen', steps=50, loss=1e-10, merge_invariant=True,
771
+ func='rosen', steps=50, loss=0.2, merge_invariant=True,
755
772
  sphere_steps=10, sphere_loss=0,
756
773
  )
757
774
 
@@ -766,26 +783,26 @@ LBFGS = Run(
766
783
 
767
784
  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
768
785
  LSR1 = Run(
769
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
770
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
786
+ func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
787
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
771
788
  needs_closure=True,
772
789
  func='rosen', steps=50, loss=0, merge_invariant=True,
773
790
  sphere_steps=10, sphere_loss=0,
774
791
  )
775
792
 
776
- # ---------------------------- quasi_newton/olbfgs --------------------------- #
777
- OnlineLBFGS = Run(
778
- func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
779
- sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
780
- needs_closure=True,
781
- func='rosen', steps=50, loss=0, merge_invariant=True,
782
- sphere_steps=10, sphere_loss=0,
783
- )
793
+ # # ---------------------------- quasi_newton/olbfgs --------------------------- #
794
+ # OnlineLBFGS = Run(
795
+ # func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
796
+ # sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
797
+ # needs_closure=True,
798
+ # func='rosen', steps=50, loss=0, merge_invariant=True,
799
+ # sphere_steps=10, sphere_loss=0,
800
+ # )
784
801
 
785
802
  # ---------------------------- second_order/newton --------------------------- #
786
803
  Newton = Run(
787
- func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
788
- sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
804
+ func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
805
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
789
806
  needs_closure=True,
790
807
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
791
808
  sphere_steps=2, sphere_loss=1e-9,
@@ -793,8 +810,8 @@ Newton = Run(
793
810
 
794
811
  # --------------------------- second_order/newton_cg -------------------------- #
795
812
  NewtonCG = Run(
796
- func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
797
- sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
813
+ func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
814
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
798
815
  needs_closure=True,
799
816
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
800
817
  sphere_steps=2, sphere_loss=3e-4,
@@ -802,11 +819,11 @@ NewtonCG = Run(
802
819
 
803
820
  # ---------------------------- smoothing/gaussian ---------------------------- #
804
821
  GaussianHomotopy = Run(
805
- func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
806
- sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
822
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
823
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
807
824
  needs_closure=True,
808
- func='booth', steps=20, loss=0.1, merge_invariant=True,
809
- sphere_steps=10, sphere_loss=200,
825
+ func='booth', steps=20, loss=0.01, merge_invariant=True,
826
+ sphere_steps=10, sphere_loss=1,
810
827
  )
811
828
 
812
829
  # ---------------------------- smoothing/laplacian --------------------------- #
@@ -860,7 +877,7 @@ SophiaH = Run(
860
877
  sphere_steps=10, sphere_loss=40,
861
878
  )
862
879
 
863
- # -------------------------- optimizers/higher_order ------------------------- #
880
+ # -------------------------- higher_order ------------------------- #
864
881
  HigherOrderNewton = Run(
865
882
  func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
866
883
  sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
@@ -869,15 +886,33 @@ HigherOrderNewton = Run(
869
886
  sphere_steps=1, sphere_loss=1e-10,
870
887
  )
871
888
 
889
+ # ---------------------------- optimizers/ladagrad --------------------------- #
890
+ LMAdagrad = Run(
891
+ func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
892
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
893
+ needs_closure=False,
894
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
895
+ sphere_steps=20, sphere_loss=1e-9,
896
+ )
897
+
898
+ # ------------------------------ optimizers/adan ----------------------------- #
899
+ Adan = Run(
900
+ func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
901
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
902
+ needs_closure=False,
903
+ func='booth', steps=50, loss=60, merge_invariant=True,
904
+ sphere_steps=20, sphere_loss=60,
905
+ )
906
+
872
907
  # ------------------------------------ CGs ----------------------------------- #
873
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
908
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.DYHS, tz.m.ProjectedGradientMethod):
874
909
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
875
910
  # but also test 10 to make sure it doesn't explode after converging
876
911
  Run(
877
912
  func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
878
913
  sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
879
914
  needs_closure=True,
880
- func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=False, # strong wolfe adds float imprecision
915
+ func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
881
916
  sphere_steps=sphere_steps_, sphere_loss=0,
882
917
  )
883
918
 
@@ -885,17 +920,22 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
885
920
  # stability test
886
921
  for QN in (
887
922
  tz.m.BFGS,
923
+ partial(tz.m.BFGS, inverse=False),
888
924
  tz.m.SR1,
925
+ partial(tz.m.SR1, inverse=False),
889
926
  tz.m.DFP,
927
+ partial(tz.m.DFP, inverse=False),
890
928
  tz.m.BroydenGood,
929
+ partial(tz.m.BroydenGood, inverse=False),
891
930
  tz.m.BroydenBad,
931
+ partial(tz.m.BroydenBad, inverse=False),
892
932
  tz.m.Greenstadt1,
893
933
  tz.m.Greenstadt2,
894
- tz.m.ColumnUpdatingMethod,
934
+ tz.m.ICUM,
895
935
  tz.m.ThomasOptimalMethod,
896
936
  tz.m.FletcherVMM,
897
937
  tz.m.Horisho,
898
- lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
938
+ partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
899
939
  tz.m.Pearson,
900
940
  tz.m.ProjectedNewtonRaphson,
901
941
  tz.m.PSB,
@@ -903,10 +943,10 @@ for QN in (
903
943
  tz.m.SSVM,
904
944
  ):
905
945
  Run(
906
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
907
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
946
+ func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
947
+ sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
908
948
  needs_closure=True,
909
- func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
949
+ func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
910
950
  sphere_steps=10, sphere_loss=1e-20,
911
951
  )
912
952
 
@@ -977,22 +977,23 @@ def test_rademacher_like(big_tl: TensorList):
977
977
 
978
978
  @pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
979
979
  def test_sample_like(simple_tl: TensorList, dist):
980
- eps_scalar = 2.0
981
- result_tl_scalar = simple_tl.sample_like(eps_scalar, distribution=dist)
980
+ eps_scalar = 1
981
+ result_tl_scalar = simple_tl.sample_like(distribution=dist)
982
982
  assert isinstance(result_tl_scalar, TensorList)
983
983
  assert result_tl_scalar.shape == simple_tl.shape
984
984
 
985
- eps_list = [0.5, 1.0, 1.5]
986
- result_tl_list = simple_tl.sample_like(eps_list, distribution=dist)
985
+ eps_list = [1.0,]
986
+ result_tl_list = simple_tl.sample_like(distribution=dist)
987
987
  assert isinstance(result_tl_list, TensorList)
988
988
  assert result_tl_list.shape == simple_tl.shape
989
989
 
990
990
  # Basic checks based on distribution
991
991
  if dist == 'uniform':
992
- assert all(torch.all((t >= -eps_scalar/2) & (t <= eps_scalar/2)) for t in result_tl_scalar)
993
- assert all(torch.all((t >= -e/2) & (t <= e/2)) for t, e in zip(result_tl_list, eps_list))
992
+ assert all(torch.all((t >= -eps_scalar) & (t <= eps_scalar)) for t in result_tl_scalar)
993
+ assert all(torch.all((t >= -e) & (t <= e)) for t, e in zip(result_tl_list, eps_list))
994
994
  elif dist == 'sphere':
995
- assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
995
+ # assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
996
+ pass
996
997
  # Cannot check list version easily
997
998
  elif dist == 'rademacher':
998
999
  assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
@@ -156,6 +156,7 @@ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
156
156
  for k,v in v1.__dict__.items():
157
157
  if not k.startswith('__'):
158
158
  # if k == 'post_step_hooks': continue
159
+ if k == 'storage': continue
159
160
  if k == 'update' and clone_update:
160
161
  if v1.update is None or v2.update is None:
161
162
  assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
@@ -0,0 +1,4 @@
1
+ from . import core, optim, utils
2
+ from .core import Modular
3
+ from .utils import set_compilation
4
+ from . import modules as m
@@ -0,0 +1,2 @@
1
+ from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
2
+ from .transform import Target, TensorwiseTransform, Transform, apply_transform