torchzero 0.3.9__tar.gz → 0.3.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. {torchzero-0.3.9 → torchzero-0.3.11}/PKG-INFO +76 -51
  2. {torchzero-0.3.9 → torchzero-0.3.11}/README.md +75 -50
  3. {torchzero-0.3.9 → torchzero-0.3.11}/docs/source/conf.py +6 -4
  4. torchzero-0.3.11/docs/source/docstring template.py +46 -0
  5. {torchzero-0.3.9 → torchzero-0.3.11}/pyproject.toml +2 -2
  6. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_identical.py +2 -3
  7. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_opts.py +115 -68
  8. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_tensorlist.py +2 -2
  9. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_vars.py +62 -61
  10. torchzero-0.3.11/torchzero/core/__init__.py +2 -0
  11. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/core/module.py +185 -53
  12. torchzero-0.3.11/torchzero/core/transform.py +420 -0
  13. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/__init__.py +3 -1
  14. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/clipping/clipping.py +120 -23
  15. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/clipping/ema_clipping.py +37 -22
  16. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/clipping/growth_clipping.py +20 -21
  17. torchzero-0.3.11/torchzero/modules/experimental/__init__.py +41 -0
  18. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/absoap.py +53 -156
  19. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/adadam.py +22 -15
  20. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/adamY.py +21 -25
  21. torchzero-0.3.11/torchzero/modules/experimental/adam_lambertw.py +149 -0
  22. torchzero-0.3.9/torchzero/modules/line_search/trust_region.py → torchzero-0.3.11/torchzero/modules/experimental/adaptive_step_size.py +37 -8
  23. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/adasoap.py +24 -129
  24. torchzero-0.3.11/torchzero/modules/experimental/cosine.py +214 -0
  25. torchzero-0.3.11/torchzero/modules/experimental/cubic_adam.py +97 -0
  26. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/curveball.py +12 -12
  27. {torchzero-0.3.9/torchzero/modules/projections → torchzero-0.3.11/torchzero/modules/experimental}/dct.py +11 -11
  28. torchzero-0.3.11/torchzero/modules/experimental/eigendescent.py +120 -0
  29. torchzero-0.3.11/torchzero/modules/experimental/etf.py +195 -0
  30. torchzero-0.3.11/torchzero/modules/experimental/exp_adam.py +113 -0
  31. torchzero-0.3.11/torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  32. {torchzero-0.3.9/torchzero/modules/projections → torchzero-0.3.11/torchzero/modules/experimental}/fft.py +10 -10
  33. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/gradmin.py +2 -2
  34. torchzero-0.3.11/torchzero/modules/experimental/hnewton.py +85 -0
  35. {torchzero-0.3.9/torchzero/modules/quasi_newton → torchzero-0.3.11/torchzero/modules}/experimental/modular_lbfgs.py +49 -50
  36. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/newton_solver.py +11 -11
  37. torchzero-0.3.11/torchzero/modules/experimental/newtonnewton.py +92 -0
  38. torchzero-0.3.11/torchzero/modules/experimental/parabolic_search.py +220 -0
  39. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  40. torchzero-0.3.9/torchzero/modules/projections/structural.py → torchzero-0.3.11/torchzero/modules/experimental/structural_projections.py +12 -54
  41. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  42. torchzero-0.3.11/torchzero/modules/experimental/tensor_adagrad.py +42 -0
  43. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/functional.py +12 -2
  44. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/grad_approximation/fdm.py +31 -4
  45. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  46. torchzero-0.3.11/torchzero/modules/grad_approximation/grad_approximator.py +111 -0
  47. torchzero-0.3.11/torchzero/modules/grad_approximation/rfdm.py +519 -0
  48. torchzero-0.3.11/torchzero/modules/higher_order/__init__.py +1 -0
  49. torchzero-0.3.11/torchzero/modules/higher_order/higher_order_newton.py +319 -0
  50. torchzero-0.3.11/torchzero/modules/line_search/__init__.py +5 -0
  51. torchzero-0.3.11/torchzero/modules/line_search/adaptive.py +99 -0
  52. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/line_search/backtracking.py +75 -31
  53. torchzero-0.3.11/torchzero/modules/line_search/line_search.py +239 -0
  54. torchzero-0.3.11/torchzero/modules/line_search/polynomial.py +233 -0
  55. torchzero-0.3.11/torchzero/modules/line_search/scipy.py +52 -0
  56. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/line_search/strong_wolfe.py +52 -36
  57. torchzero-0.3.11/torchzero/modules/misc/__init__.py +27 -0
  58. torchzero-0.3.11/torchzero/modules/misc/debug.py +48 -0
  59. torchzero-0.3.11/torchzero/modules/misc/escape.py +60 -0
  60. torchzero-0.3.11/torchzero/modules/misc/gradient_accumulation.py +70 -0
  61. torchzero-0.3.11/torchzero/modules/misc/misc.py +316 -0
  62. torchzero-0.3.11/torchzero/modules/misc/multistep.py +158 -0
  63. torchzero-0.3.11/torchzero/modules/misc/regularization.py +171 -0
  64. torchzero-0.3.11/torchzero/modules/misc/split.py +103 -0
  65. {torchzero-0.3.9/torchzero/modules/ops → torchzero-0.3.11/torchzero/modules/misc}/switch.py +48 -7
  66. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/momentum/__init__.py +1 -1
  67. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/momentum/averaging.py +25 -10
  68. torchzero-0.3.11/torchzero/modules/momentum/cautious.py +256 -0
  69. torchzero-0.3.11/torchzero/modules/momentum/ema.py +224 -0
  70. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/momentum/experimental.py +21 -13
  71. torchzero-0.3.11/torchzero/modules/momentum/matrix_momentum.py +193 -0
  72. torchzero-0.3.11/torchzero/modules/momentum/momentum.py +64 -0
  73. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/ops/__init__.py +3 -31
  74. torchzero-0.3.11/torchzero/modules/ops/accumulate.py +91 -0
  75. torchzero-0.3.11/torchzero/modules/ops/binary.py +286 -0
  76. torchzero-0.3.11/torchzero/modules/ops/multi.py +198 -0
  77. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/ops/reduce.py +31 -23
  78. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/ops/unary.py +37 -21
  79. torchzero-0.3.11/torchzero/modules/ops/utility.py +120 -0
  80. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/__init__.py +12 -3
  81. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/adagrad.py +48 -29
  82. torchzero-0.3.11/torchzero/modules/optimizers/adahessian.py +223 -0
  83. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/adam.py +35 -37
  84. torchzero-0.3.11/torchzero/modules/optimizers/adan.py +110 -0
  85. torchzero-0.3.11/torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  86. torchzero-0.3.11/torchzero/modules/optimizers/esgd.py +171 -0
  87. torchzero-0.3.11/torchzero/modules/optimizers/ladagrad.py +183 -0
  88. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/lion.py +4 -4
  89. torchzero-0.3.11/torchzero/modules/optimizers/mars.py +91 -0
  90. torchzero-0.3.11/torchzero/modules/optimizers/msam.py +186 -0
  91. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/muon.py +32 -7
  92. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/orthograd.py +4 -5
  93. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/rmsprop.py +19 -19
  94. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/rprop.py +89 -52
  95. torchzero-0.3.11/torchzero/modules/optimizers/sam.py +163 -0
  96. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/shampoo.py +55 -27
  97. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/optimizers/soap.py +40 -37
  98. torchzero-0.3.11/torchzero/modules/optimizers/sophia_h.py +186 -0
  99. torchzero-0.3.11/torchzero/modules/projections/__init__.py +3 -0
  100. torchzero-0.3.11/torchzero/modules/projections/cast.py +51 -0
  101. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/projections/galore.py +4 -2
  102. torchzero-0.3.11/torchzero/modules/projections/projection.py +338 -0
  103. torchzero-0.3.11/torchzero/modules/quasi_newton/__init__.py +46 -0
  104. torchzero-0.3.11/torchzero/modules/quasi_newton/cg.py +369 -0
  105. torchzero-0.3.11/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  106. torchzero-0.3.11/torchzero/modules/quasi_newton/lbfgs.py +286 -0
  107. torchzero-0.3.11/torchzero/modules/quasi_newton/lsr1.py +218 -0
  108. torchzero-0.3.11/torchzero/modules/quasi_newton/quasi_newton.py +1331 -0
  109. torchzero-0.3.11/torchzero/modules/quasi_newton/trust_region.py +397 -0
  110. torchzero-0.3.11/torchzero/modules/second_order/__init__.py +3 -0
  111. torchzero-0.3.11/torchzero/modules/second_order/newton.py +338 -0
  112. torchzero-0.3.11/torchzero/modules/second_order/newton_cg.py +374 -0
  113. torchzero-0.3.11/torchzero/modules/second_order/nystrom.py +271 -0
  114. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/smoothing/gaussian.py +55 -21
  115. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/smoothing/laplacian.py +20 -12
  116. torchzero-0.3.11/torchzero/modules/step_size/__init__.py +2 -0
  117. torchzero-0.3.11/torchzero/modules/step_size/adaptive.py +122 -0
  118. torchzero-0.3.11/torchzero/modules/step_size/lr.py +154 -0
  119. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/weight_decay/__init__.py +1 -1
  120. torchzero-0.3.11/torchzero/modules/weight_decay/weight_decay.py +168 -0
  121. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/wrappers/optim_wrapper.py +40 -12
  122. torchzero-0.3.11/torchzero/optim/wrappers/directsearch.py +281 -0
  123. torchzero-0.3.11/torchzero/optim/wrappers/fcmaes.py +105 -0
  124. torchzero-0.3.11/torchzero/optim/wrappers/mads.py +89 -0
  125. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/wrappers/nevergrad.py +20 -5
  126. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/wrappers/nlopt.py +28 -14
  127. torchzero-0.3.11/torchzero/optim/wrappers/optuna.py +70 -0
  128. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/wrappers/scipy.py +167 -16
  129. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/__init__.py +3 -7
  130. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/derivatives.py +5 -4
  131. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/__init__.py +1 -1
  132. torchzero-0.3.11/torchzero/utils/linalg/solve.py +408 -0
  133. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/numberlist.py +2 -0
  134. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/optimizer.py +55 -74
  135. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/python_tools.py +27 -4
  136. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/tensorlist.py +40 -28
  137. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero.egg-info/PKG-INFO +76 -51
  138. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero.egg-info/SOURCES.txt +48 -20
  139. torchzero-0.3.9/torchzero/core/__init__.py +0 -3
  140. torchzero-0.3.9/torchzero/core/preconditioner.py +0 -138
  141. torchzero-0.3.9/torchzero/core/transform.py +0 -252
  142. torchzero-0.3.9/torchzero/modules/experimental/__init__.py +0 -15
  143. torchzero-0.3.9/torchzero/modules/experimental/algebraic_newton.py +0 -145
  144. torchzero-0.3.9/torchzero/modules/experimental/soapy.py +0 -290
  145. torchzero-0.3.9/torchzero/modules/experimental/spectral.py +0 -288
  146. torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py +0 -111
  147. torchzero-0.3.9/torchzero/modules/experimental/tropical_newton.py +0 -136
  148. torchzero-0.3.9/torchzero/modules/grad_approximation/grad_approximator.py +0 -66
  149. torchzero-0.3.9/torchzero/modules/grad_approximation/rfdm.py +0 -259
  150. torchzero-0.3.9/torchzero/modules/line_search/__init__.py +0 -5
  151. torchzero-0.3.9/torchzero/modules/line_search/line_search.py +0 -181
  152. torchzero-0.3.9/torchzero/modules/line_search/scipy.py +0 -37
  153. torchzero-0.3.9/torchzero/modules/lr/__init__.py +0 -2
  154. torchzero-0.3.9/torchzero/modules/lr/lr.py +0 -59
  155. torchzero-0.3.9/torchzero/modules/lr/step_size.py +0 -97
  156. torchzero-0.3.9/torchzero/modules/momentum/cautious.py +0 -181
  157. torchzero-0.3.9/torchzero/modules/momentum/ema.py +0 -173
  158. torchzero-0.3.9/torchzero/modules/momentum/matrix_momentum.py +0 -124
  159. torchzero-0.3.9/torchzero/modules/momentum/momentum.py +0 -43
  160. torchzero-0.3.9/torchzero/modules/ops/accumulate.py +0 -65
  161. torchzero-0.3.9/torchzero/modules/ops/binary.py +0 -240
  162. torchzero-0.3.9/torchzero/modules/ops/debug.py +0 -25
  163. torchzero-0.3.9/torchzero/modules/ops/misc.py +0 -419
  164. torchzero-0.3.9/torchzero/modules/ops/multi.py +0 -137
  165. torchzero-0.3.9/torchzero/modules/ops/split.py +0 -75
  166. torchzero-0.3.9/torchzero/modules/ops/utility.py +0 -112
  167. torchzero-0.3.9/torchzero/modules/optimizers/sophia_h.py +0 -129
  168. torchzero-0.3.9/torchzero/modules/projections/__init__.py +0 -5
  169. torchzero-0.3.9/torchzero/modules/projections/projection.py +0 -244
  170. torchzero-0.3.9/torchzero/modules/quasi_newton/__init__.py +0 -7
  171. torchzero-0.3.9/torchzero/modules/quasi_newton/cg.py +0 -218
  172. torchzero-0.3.9/torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  173. torchzero-0.3.9/torchzero/modules/quasi_newton/lbfgs.py +0 -229
  174. torchzero-0.3.9/torchzero/modules/quasi_newton/lsr1.py +0 -174
  175. torchzero-0.3.9/torchzero/modules/quasi_newton/olbfgs.py +0 -196
  176. torchzero-0.3.9/torchzero/modules/quasi_newton/quasi_newton.py +0 -476
  177. torchzero-0.3.9/torchzero/modules/second_order/__init__.py +0 -3
  178. torchzero-0.3.9/torchzero/modules/second_order/newton.py +0 -147
  179. torchzero-0.3.9/torchzero/modules/second_order/newton_cg.py +0 -84
  180. torchzero-0.3.9/torchzero/modules/second_order/nystrom.py +0 -168
  181. torchzero-0.3.9/torchzero/modules/weight_decay/weight_decay.py +0 -52
  182. torchzero-0.3.9/torchzero/utils/linalg/solve.py +0 -169
  183. {torchzero-0.3.9 → torchzero-0.3.11}/LICENSE +0 -0
  184. {torchzero-0.3.9 → torchzero-0.3.11}/setup.cfg +0 -0
  185. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_module.py +0 -0
  186. {torchzero-0.3.9 → torchzero-0.3.11}/tests/test_utils_optimizer.py +0 -0
  187. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/__init__.py +0 -0
  188. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/clipping/__init__.py +0 -0
  189. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/grad_approximation/__init__.py +0 -0
  190. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/smoothing/__init__.py +0 -0
  191. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/modules/wrappers/__init__.py +0 -0
  192. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/__init__.py +0 -0
  193. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/utility/__init__.py +0 -0
  194. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/utility/split.py +0 -0
  195. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/optim/wrappers/__init__.py +0 -0
  196. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/compile.py +0 -0
  197. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/benchmark.py +0 -0
  198. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  199. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/orthogonalize.py +0 -0
  200. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/qr.py +0 -0
  201. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/linalg/svd.py +0 -0
  202. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/ops.py +0 -0
  203. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/optuna_tools.py +0 -0
  204. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/params.py +0 -0
  205. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero/utils/torch_tools.py +0 -0
  206. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero.egg-info/dependency_links.txt +0 -0
  207. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero.egg-info/requires.txt +0 -0
  208. {torchzero-0.3.9 → torchzero-0.3.11}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.9 → torchzero-0.3.11}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.9
+ Version: 0.3.11
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
@@ -45,8 +45,6 @@ Dynamic: license-file

  `torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.

- NOTE: torchzero is in active development, currently docs are in a state of flux.
-
  ## Installation

  ```bash
@@ -113,31 +111,21 @@ for epoch in range(100):
  `torchzero` provides a huge number of various modules:

  * **Optimizers**: Optimization algorithms.
- * `Adam`.
- * `Shampoo`.
- * `SOAP` (my current recommendation).
- * `Muon`.
- * `SophiaH`.
- * `Adagrad` and `FullMatrixAdagrad`.
- * `Lion`.
- * `RMSprop`.
- * `OrthoGrad`.
- * `Rprop`.
+ * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.

  Additionally many other optimizers can be easily defined via modules:
  * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
  * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
  * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
- * Full matrix version of any diagonal optimizer, like Adam: `tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))`
+ * Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
  * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`

  * **Momentum**:
- * `NAG`: Nesterov Accelerated Gradient.
  * `HeavyBall`: Classic momentum (Polyak's momentum).
+ * `NAG`: Nesterov Accelerated Gradient.
  * `EMA`: Exponential moving average.
- * `Averaging` (`Medianveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
+ * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
- * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.

  * **Stabilization**: Gradient stabilization techniques.
  * `ClipNorm`: Clips gradient L2 norm.
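The optimizer recipes in the hunk above (Grams, LaProp, Signum, the full-matrix and cautious variants) are built by chaining modules. A minimal sketch of how two of those recipes would be assembled, assuming the `tz.Modular` constructor used elsewhere in this README; the trailing `tz.m.LR` step size is an illustrative choice, not part of the diff:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 2)

# LaProp recipe from the list above: RMSprop preconditioning, then an EMA of the update
laprop = tz.Modular(model.parameters(), tz.m.RMSprop(), tz.m.EMA(0.9), tz.m.LR(1e-3))

# cautious SOAP recipe from the list above: SOAP followed by the Cautious module
cautious_soap = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.Cautious(), tz.m.LR(1e-3))
```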
@@ -154,31 +142,42 @@ for epoch in range(100):

  * **Second order**: Second order methods.
  * `Newton`: Classic Newton's method.
- * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
+ * `InverseFreeNewton`: Inverse-free version of Newton's method.
+ * `NewtonCG`: Matrix-free newton's method with conjugate gradient or minimal residual solvers.
+ * `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning.
+ * `HigherOrderNewton`: Higher order Newton's method with trust region.

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
  * `LSR1`: Limited-memory SR1.
  * `OnlineLBFGS`: Online LBFGS.
- * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
- * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+ * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
+ * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
+ * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
+
+ * **Trust Region** Trust region can work with exact hessian or any of the quasi-newton methods (L-BFGS support is WIP)
+ * `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
+ * `CubicRegularization`: Cubic regularization, works better with exact hessian.

  * **Line Search**:
  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
  * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
  * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
- * `TrustRegion`: First order trust region method.

  * **Learning Rate**:
  * `LR`: Controls learning rate and adds support for LR schedulers.
- * `PolyakStepSize`: Polyak's method.
- * `Warmup`: Learning rate warmup.
+ * `PolyakStepSize`: Polyak's subgradient method.
+ * `BarzilaiBorwein`: Barzilai-Borwein step-size.
+ * `Warmup`, `WarmupNormCLip`: Learning rate warmup.

  * **Projections**: This can implement things like GaLore but I haven't done that yet.
- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
- * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
+ <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
+ * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
+ This is WIP
+ * `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
+ * `ViewAsReal`: put if you have complex paramters.

  * **Smoothing**: Smoothing-based optimization methods.
  * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
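The quasi-newton and line-search modules listed in this hunk chain together the same way as the simpler recipes. A hedged sketch, assuming default constructor arguments (only the module names come from the list above, everything else is illustrative); line searches re-evaluate the objective, so the step is driven by a closure as the Advanced Usage section of the README describes:

```python
import torch
import torch.nn.functional as F
import torchzero as tz

model = torch.nn.Linear(10, 2)
X, y = torch.randn(64, 10), torch.randn(64, 2)

# L-BFGS direction followed by a backtracking line search (default arguments are assumptions)
opt = tz.Modular(model.parameters(), tz.m.LBFGS(), tz.m.Backtracking())

def closure(backward=True):
    # the line search may call this several times per step
    loss = F.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```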
@@ -194,6 +193,8 @@ for epoch in range(100):

  * **Experimental**: various horrible atrocities

+ A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
+
  ## Advanced Usage

  ### Closure
@@ -312,20 +313,21 @@ not in the module itself. Also both per-parameter settings and state are stored

  ```python
  import torch
- from torchzero.core import Module, Vars
+ from torchzero.core import Module, Var

  class HeavyBall(Module):
      def __init__(self, momentum: float = 0.9, dampening: float = 0):
          defaults = dict(momentum=momentum, dampening=dampening)
          super().__init__(defaults)

-     def step(self, vars: Vars):
-         # a module takes a Vars object, modifies it or creates a new one, and returns it
-         # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
+     def step(self, var: Var):
+         # Var object holds all attributes used for optimization - parameters, gradient, update, etc.
+         # a module takes a Var object, modifies it or creates a new one, and returns it
+         # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
          # for now we are only interested in update, and we will apply the heavyball rule to it.

-         params = vars.params
-         update = vars.get_update() # list of tensors
+         params = var.params
+         update = var.get_update() # list of tensors

          exp_avg_list = []
          for p, u in zip(params, update):
@@ -346,34 +348,57 @@ class HeavyBall(Module):
              # and it is part of self.state
              exp_avg_list.append(buf.clone())

-         # set new update to vars
-         vars.update = exp_avg_list
-         return vars
+         # set new update to var
+         var.update = exp_avg_list
+         return var
  ```

- There are a some specialized base modules that make it much easier to implement some specific things.
+ More in-depth guide will be available in the documentation in the future.
+
+ ## Other stuff

- * `GradApproximator` for gradient approximations
- * `LineSearch` for line searches
- * `Preconditioner` for preconditioners
- * `Projection` for projections like GaLore or into fourier domain.
- * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
- * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
+ There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.

- The documentation on how to actually use them is to write itself in the near future.
+ ---

- ## License
+ ### Scipy

- This project is licensed under the MIT License
+ #### torchzero.optim.wrappers.scipy.ScipyMinimize

- ## Project Links
+ A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.

- TODO (there are docs but from very old version)
+ #### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute

- ## Other stuff
+ Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
+
+ ---
+
+ ### NLOpt
+
+ #### torchzero.optim.wrappers.nlopt.NLOptWrapper

- There are also wrappers providing `torch.optim.Optimizer` interface for for `scipy.optimize`, NLOpt and Nevergrad.
+ A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
+
+ ---
+
+ ### Nevergrad
+
+ #### torchzero.optim.wrappers.nevergrad.NevergradWrapper
+
+ A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
+
+ ---
+
+ ### fast-cma-es
+
+ #### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
+
+ A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
+
+ # License
+
+ This project is licensed under the MIT License

- They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
+ # Project Links

- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+ The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
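The wrapper classes introduced in the hunk above require a closure that accepts a `backward` argument, as the README's Advanced Usage section describes. A minimal hedged sketch of that pattern with the scipy wrapper; only the class path comes from the text above, the `method` keyword and the rest of the setup are assumptions:

```python
import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyMinimize

model = torch.nn.Linear(10, 2)
X, y = torch.randn(64, 10), torch.randn(64, 2)

opt = ScipyMinimize(model.parameters(), method='l-bfgs-b')  # method kwarg is an assumption

def closure(backward=True):
    # the wrapper can request the loss with or without gradients via the backward flag
    loss = F.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```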
{torchzero-0.3.9 → torchzero-0.3.11}/README.md

@@ -6,8 +6,6 @@

  `torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.

- NOTE: torchzero is in active development, currently docs are in a state of flux.
-
  ## Installation

  ```bash
@@ -74,31 +72,21 @@ for epoch in range(100):
  `torchzero` provides a huge number of various modules:

  * **Optimizers**: Optimization algorithms.
- * `Adam`.
- * `Shampoo`.
- * `SOAP` (my current recommendation).
- * `Muon`.
- * `SophiaH`.
- * `Adagrad` and `FullMatrixAdagrad`.
- * `Lion`.
- * `RMSprop`.
- * `OrthoGrad`.
- * `Rprop`.
+ * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.

  Additionally many other optimizers can be easily defined via modules:
  * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
  * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
  * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
- * Full matrix version of any diagonal optimizer, like Adam: `tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))`
+ * Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
  * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`

  * **Momentum**:
- * `NAG`: Nesterov Accelerated Gradient.
  * `HeavyBall`: Classic momentum (Polyak's momentum).
+ * `NAG`: Nesterov Accelerated Gradient.
  * `EMA`: Exponential moving average.
- * `Averaging` (`Medianveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
+ * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
- * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.

  * **Stabilization**: Gradient stabilization techniques.
  * `ClipNorm`: Clips gradient L2 norm.
@@ -115,31 +103,42 @@ for epoch in range(100):

  * **Second order**: Second order methods.
  * `Newton`: Classic Newton's method.
- * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
+ * `InverseFreeNewton`: Inverse-free version of Newton's method.
+ * `NewtonCG`: Matrix-free newton's method with conjugate gradient or minimal residual solvers.
+ * `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning.
+ * `HigherOrderNewton`: Higher order Newton's method with trust region.

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
  * `LSR1`: Limited-memory SR1.
  * `OnlineLBFGS`: Online LBFGS.
- * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
- * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+ * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
+ * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
+ * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
+
+ * **Trust Region** Trust region can work with exact hessian or any of the quasi-newton methods (L-BFGS support is WIP)
+ * `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
+ * `CubicRegularization`: Cubic regularization, works better with exact hessian.

  * **Line Search**:
  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
  * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
  * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
- * `TrustRegion`: First order trust region method.

  * **Learning Rate**:
  * `LR`: Controls learning rate and adds support for LR schedulers.
- * `PolyakStepSize`: Polyak's method.
- * `Warmup`: Learning rate warmup.
+ * `PolyakStepSize`: Polyak's subgradient method.
+ * `BarzilaiBorwein`: Barzilai-Borwein step-size.
+ * `Warmup`, `WarmupNormCLip`: Learning rate warmup.

  * **Projections**: This can implement things like GaLore but I haven't done that yet.
- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
- * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
+ <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
+ * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
+ This is WIP
+ * `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
+ * `ViewAsReal`: put if you have complex paramters.

  * **Smoothing**: Smoothing-based optimization methods.
  * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
@@ -155,6 +154,8 @@ for epoch in range(100):

  * **Experimental**: various horrible atrocities

+ A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
+
  ## Advanced Usage

  ### Closure
@@ -273,20 +274,21 @@ not in the module itself. Also both per-parameter settings and state are stored

  ```python
  import torch
- from torchzero.core import Module, Vars
+ from torchzero.core import Module, Var

  class HeavyBall(Module):
      def __init__(self, momentum: float = 0.9, dampening: float = 0):
          defaults = dict(momentum=momentum, dampening=dampening)
          super().__init__(defaults)

-     def step(self, vars: Vars):
-         # a module takes a Vars object, modifies it or creates a new one, and returns it
-         # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
+     def step(self, var: Var):
+         # Var object holds all attributes used for optimization - parameters, gradient, update, etc.
+         # a module takes a Var object, modifies it or creates a new one, and returns it
+         # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
          # for now we are only interested in update, and we will apply the heavyball rule to it.

-         params = vars.params
-         update = vars.get_update() # list of tensors
+         params = var.params
+         update = var.get_update() # list of tensors

          exp_avg_list = []
          for p, u in zip(params, update):
@@ -307,34 +309,57 @@ class HeavyBall(Module):
              # and it is part of self.state
              exp_avg_list.append(buf.clone())

-         # set new update to vars
-         vars.update = exp_avg_list
-         return vars
+         # set new update to var
+         var.update = exp_avg_list
+         return var
  ```

- There are a some specialized base modules that make it much easier to implement some specific things.
+ More in-depth guide will be available in the documentation in the future.
+
+ ## Other stuff

- * `GradApproximator` for gradient approximations
- * `LineSearch` for line searches
- * `Preconditioner` for preconditioners
- * `Projection` for projections like GaLore or into fourier domain.
- * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
- * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
+ There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.

- The documentation on how to actually use them is to write itself in the near future.
+ ---

- ## License
+ ### Scipy

- This project is licensed under the MIT License
+ #### torchzero.optim.wrappers.scipy.ScipyMinimize

- ## Project Links
+ A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.

- TODO (there are docs but from very old version)
+ #### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute

- ## Other stuff
+ Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
+
+ ---
+
+ ### NLOpt
+
+ #### torchzero.optim.wrappers.nlopt.NLOptWrapper

- There are also wrappers providing `torch.optim.Optimizer` interface for for `scipy.optimize`, NLOpt and Nevergrad.
+ A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
+
+ ---
+
+ ### Nevergrad
+
+ #### torchzero.optim.wrappers.nevergrad.NevergradWrapper
+
+ A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
+
+ ---
+
+ ### fast-cma-es
+
+ #### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
+
+ A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
+
+ # License
+
+ This project is licensed under the MIT License

- They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
+ # Project Links

- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+ The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
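The custom `HeavyBall` example in the README hunks above shows the `Var`-based `step` contract: read the incoming update, modify it, and return the `Var`. A smaller hedged sketch of the same contract, using a toy module that is not part of the package; only `Module`, `Var`, `var.get_update()` and `tz.Modular` are taken from the README text, everything else is illustrative:

```python
import torch
import torchzero as tz
from torchzero.core import Module, Var

class MulUpdate(Module):
    """Toy illustration: scales the current update by a constant factor."""
    def __init__(self, factor: float = 0.5):
        super().__init__(dict(factor=factor))
        self.factor = factor  # stored as a plain attribute for simplicity

    def step(self, var: Var):
        # same contract as the HeavyBall example: take the update, replace it, return var
        var.update = [u * self.factor for u in var.get_update()]
        return var

model = torch.nn.Linear(10, 2)
opt = tz.Modular(model.parameters(), MulUpdate(0.5), tz.m.LR(1e-2))
```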
{torchzero-0.3.9 → torchzero-0.3.11}/docs/source/conf.py

@@ -6,10 +6,10 @@
  # -- Project information -----------------------------------------------------
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
  import sys, os
- #sys.path.insert(0, os.path.abspath('.../src'))
+ #sys.path.insert(0, os.path.abspath('.../src'))

  project = 'torchzero'
- copyright = '2024, Ivan Nikishev'
+ copyright = '2025, Ivan Nikishev'
  author = 'Ivan Nikishev'

  # -- General configuration ---------------------------------------------------
@@ -24,10 +24,12 @@ extensions = [
      'sphinx.ext.githubpages',
      'sphinx.ext.napoleon',
      'autoapi.extension',
+     "myst_nb",
+
      # 'sphinx_rtd_theme',
  ]
  autosummary_generate = True
- autoapi_dirs = ['../../src']
+ autoapi_dirs = ['../../torchzero']
  autoapi_type = "python"
  # autoapi_ignore = ["*/tensorlist.py"]

@@ -48,7 +50,7 @@ exclude_patterns = []
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

  #html_theme = 'alabaster'
- html_theme = 'furo'
+ html_theme = 'sphinx_rtd_theme'
  html_static_path = ['_static']

torchzero-0.3.11/docs/source/docstring template.py

@@ -0,0 +1,46 @@
+ class MyModule:
+     """[One-line summary of the class].
+
+     [A more detailed description of the class, explaining its purpose, how it
+     works, and its typical use cases. You can use multiple paragraphs.]
+
+     .. note::
+         [Optional: Add important notes, warnings, or usage guidelines here.
+         For example, you could mention if a closure is required, discuss
+         stability, or highlight performance characteristics. Use the `.. note::`
+         directive to make it stand out in the documentation.]
+
+     Args:
+         param1 (type, optional):
+             [Description of the first parameter. Use :code:`backticks` for
+             inline code like variable names or specific values like ``"autograd"``.
+             Explain what the parameter does.] Defaults to [value].
+         param2 (type):
+             [Description of a mandatory parameter (no "optional" or "Defaults to").]
+         **kwargs:
+             [If you accept keyword arguments, describe what they are used for.]
+
+     Examples:
+         [A title or short sentence describing the first example]:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 ...
+             )
+
+         [A title or short sentence for a second, different example]:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 ...
+             )
+
+     References:
+         - [Optional: A citation for a relevant paper, book, or algorithm.]
+         - [Optional: A link to a blog post or website with more information.]
+
+     """
{torchzero-0.3.9 → torchzero-0.3.11}/pyproject.toml

@@ -1,5 +1,5 @@
  # NEW VERSION TUTORIAL FOR MYSELF
- # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
+ # STEP 1 - COMMIT NEW CHANGES AND PUSH THEM
  # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
  # STEP 3 - CREATE TAG WITH THAT VERSION
  # STEP 4 - PUSH (SYNC) CHANGES
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."

- version = "0.3.9"
+ version = "0.3.11"
  dependencies = [
      "torch",
      "numpy",
{torchzero-0.3.9 → torchzero-0.3.11}/tests/test_identical.py

@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s

  @pytest.mark.parametrize('amsgrad', [True, False])
  def test_adam(amsgrad):
-     # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
-     # pytorch applies debiasing separately so it is applied before epsilo
+     torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
      tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
      tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
      tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
          tz.m.Debias2(beta=0.999),
          tz.m.Add(1e-8)]
      ))
-     tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+     tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)

      _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
      for fn in tz_fns: