torchzero 0.3.11__tar.gz → 0.3.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. torchzero-0.3.14/PKG-INFO +14 -0
  2. {torchzero-0.3.11 → torchzero-0.3.14}/pyproject.toml +1 -1
  3. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_opts.py +95 -76
  4. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_tensorlist.py +8 -7
  5. torchzero-0.3.14/torchzero/__init__.py +4 -0
  6. torchzero-0.3.14/torchzero/core/__init__.py +2 -0
  7. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/core/module.py +229 -72
  8. torchzero-0.3.14/torchzero/core/reformulation.py +65 -0
  9. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/core/transform.py +44 -24
  10. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/__init__.py +13 -5
  11. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/__init__.py +5 -2
  12. torchzero-0.3.14/torchzero/modules/adaptive/adagrad.py +356 -0
  13. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/adahessian.py +53 -52
  14. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/adam.py +0 -3
  15. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/adan.py +26 -40
  16. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/adaptive_heavyball.py +3 -6
  17. torchzero-0.3.14/torchzero/modules/adaptive/aegd.py +54 -0
  18. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/esgd.py +1 -1
  19. torchzero-0.3.11/torchzero/modules/optimizers/ladagrad.py → torchzero-0.3.14/torchzero/modules/adaptive/lmadagrad.py +42 -39
  20. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/mars.py +24 -36
  21. torchzero-0.3.14/torchzero/modules/adaptive/matrix_momentum.py +146 -0
  22. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/msam.py +14 -12
  23. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/muon.py +19 -20
  24. torchzero-0.3.14/torchzero/modules/adaptive/natural_gradient.py +175 -0
  25. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/rprop.py +0 -2
  26. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/sam.py +1 -1
  27. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/shampoo.py +8 -4
  28. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/soap.py +27 -50
  29. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/sophia_h.py +2 -3
  30. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/clipping/clipping.py +85 -92
  31. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/clipping/ema_clipping.py +5 -5
  32. torchzero-0.3.14/torchzero/modules/conjugate_gradient/__init__.py +11 -0
  33. {torchzero-0.3.11/torchzero/modules/quasi_newton → torchzero-0.3.14/torchzero/modules/conjugate_gradient}/cg.py +355 -369
  34. torchzero-0.3.14/torchzero/modules/experimental/__init__.py +18 -0
  35. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/dct.py +2 -2
  36. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/fft.py +2 -2
  37. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/gradmin.py +4 -3
  38. torchzero-0.3.14/torchzero/modules/experimental/l_infinity.py +111 -0
  39. torchzero-0.3.11/torchzero/modules/momentum/experimental.py → torchzero-0.3.14/torchzero/modules/experimental/momentum.py +3 -40
  40. torchzero-0.3.14/torchzero/modules/experimental/newton_solver.py +150 -0
  41. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/newtonnewton.py +27 -14
  42. torchzero-0.3.14/torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  43. torchzero-0.3.14/torchzero/modules/experimental/spsa1.py +93 -0
  44. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/structural_projections.py +1 -1
  45. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/functional.py +50 -14
  46. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/grad_approximation/__init__.py +1 -1
  47. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/grad_approximation/fdm.py +19 -20
  48. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  49. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  50. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/grad_approximation/rfdm.py +114 -175
  51. torchzero-0.3.14/torchzero/modules/higher_order/__init__.py +1 -0
  52. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/higher_order/higher_order_newton.py +31 -23
  53. torchzero-0.3.14/torchzero/modules/least_squares/__init__.py +1 -0
  54. torchzero-0.3.14/torchzero/modules/least_squares/gn.py +161 -0
  55. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/line_search/__init__.py +2 -2
  56. torchzero-0.3.14/torchzero/modules/line_search/_polyinterp.py +289 -0
  57. torchzero-0.3.14/torchzero/modules/line_search/adaptive.py +124 -0
  58. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/line_search/backtracking.py +83 -70
  59. torchzero-0.3.14/torchzero/modules/line_search/line_search.py +330 -0
  60. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/line_search/scipy.py +16 -4
  61. torchzero-0.3.14/torchzero/modules/line_search/strong_wolfe.py +375 -0
  62. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/__init__.py +8 -0
  63. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/debug.py +4 -4
  64. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/escape.py +9 -7
  65. torchzero-0.3.14/torchzero/modules/misc/gradient_accumulation.py +136 -0
  66. torchzero-0.3.14/torchzero/modules/misc/homotopy.py +59 -0
  67. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/misc.py +82 -15
  68. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/multistep.py +47 -11
  69. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/regularization.py +5 -9
  70. torchzero-0.3.14/torchzero/modules/misc/split.py +123 -0
  71. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/misc/switch.py +1 -1
  72. torchzero-0.3.14/torchzero/modules/momentum/__init__.py +10 -0
  73. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/momentum/averaging.py +3 -3
  74. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/momentum/cautious.py +42 -47
  75. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/momentum/momentum.py +35 -1
  76. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/__init__.py +9 -1
  77. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/binary.py +9 -8
  78. torchzero-0.3.11/torchzero/modules/momentum/ema.py → torchzero-0.3.14/torchzero/modules/ops/higher_level.py +10 -33
  79. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/multi.py +15 -15
  80. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/reduce.py +1 -1
  81. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/utility.py +12 -8
  82. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/projections/projection.py +4 -4
  83. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/quasi_newton/__init__.py +1 -16
  84. torchzero-0.3.14/torchzero/modules/quasi_newton/damping.py +105 -0
  85. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  86. torchzero-0.3.14/torchzero/modules/quasi_newton/lbfgs.py +342 -0
  87. torchzero-0.3.14/torchzero/modules/quasi_newton/lsr1.py +253 -0
  88. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  89. torchzero-0.3.14/torchzero/modules/restarts/__init__.py +7 -0
  90. torchzero-0.3.14/torchzero/modules/restarts/restars.py +253 -0
  91. torchzero-0.3.14/torchzero/modules/second_order/__init__.py +4 -0
  92. torchzero-0.3.14/torchzero/modules/second_order/multipoint.py +238 -0
  93. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/second_order/newton.py +133 -88
  94. torchzero-0.3.14/torchzero/modules/second_order/newton_cg.py +411 -0
  95. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/smoothing/__init__.py +1 -1
  96. torchzero-0.3.14/torchzero/modules/smoothing/sampling.py +300 -0
  97. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/step_size/__init__.py +1 -1
  98. torchzero-0.3.14/torchzero/modules/step_size/adaptive.py +387 -0
  99. torchzero-0.3.14/torchzero/modules/termination/__init__.py +14 -0
  100. torchzero-0.3.14/torchzero/modules/termination/termination.py +207 -0
  101. torchzero-0.3.14/torchzero/modules/trust_region/__init__.py +5 -0
  102. torchzero-0.3.14/torchzero/modules/trust_region/cubic_regularization.py +170 -0
  103. torchzero-0.3.14/torchzero/modules/trust_region/dogleg.py +92 -0
  104. torchzero-0.3.14/torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  105. torchzero-0.3.14/torchzero/modules/trust_region/trust_cg.py +99 -0
  106. torchzero-0.3.14/torchzero/modules/trust_region/trust_region.py +350 -0
  107. torchzero-0.3.14/torchzero/modules/variance_reduction/__init__.py +1 -0
  108. torchzero-0.3.14/torchzero/modules/variance_reduction/svrg.py +208 -0
  109. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/weight_decay/weight_decay.py +65 -64
  110. torchzero-0.3.14/torchzero/modules/zeroth_order/__init__.py +1 -0
  111. torchzero-0.3.14/torchzero/modules/zeroth_order/cd.py +122 -0
  112. torchzero-0.3.14/torchzero/optim/root.py +65 -0
  113. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/utility/split.py +8 -8
  114. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/directsearch.py +0 -1
  115. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/fcmaes.py +3 -2
  116. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/nlopt.py +0 -2
  117. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/optuna.py +2 -2
  118. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/scipy.py +81 -22
  119. torchzero-0.3.14/torchzero/utils/__init__.py +59 -0
  120. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/compile.py +1 -1
  121. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/derivatives.py +123 -111
  122. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/__init__.py +9 -2
  123. torchzero-0.3.14/torchzero/utils/linalg/linear_operator.py +329 -0
  124. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/matrix_funcs.py +2 -2
  125. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/orthogonalize.py +2 -1
  126. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/qr.py +2 -2
  127. torchzero-0.3.14/torchzero/utils/linalg/solve.py +480 -0
  128. torchzero-0.3.14/torchzero/utils/metrics.py +83 -0
  129. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/optimizer.py +2 -2
  130. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/python_tools.py +7 -0
  131. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/tensorlist.py +105 -34
  132. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/torch_tools.py +9 -4
  133. torchzero-0.3.14/torchzero.egg-info/PKG-INFO +14 -0
  134. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero.egg-info/SOURCES.txt +54 -47
  135. torchzero-0.3.11/LICENSE +0 -21
  136. torchzero-0.3.11/PKG-INFO +0 -404
  137. torchzero-0.3.11/README.md +0 -365
  138. torchzero-0.3.11/docs/source/conf.py +0 -59
  139. torchzero-0.3.11/docs/source/docstring template.py +0 -46
  140. torchzero-0.3.11/torchzero/__init__.py +0 -4
  141. torchzero-0.3.11/torchzero/core/__init__.py +0 -2
  142. torchzero-0.3.11/torchzero/modules/experimental/__init__.py +0 -41
  143. torchzero-0.3.11/torchzero/modules/experimental/absoap.py +0 -253
  144. torchzero-0.3.11/torchzero/modules/experimental/adadam.py +0 -118
  145. torchzero-0.3.11/torchzero/modules/experimental/adamY.py +0 -131
  146. torchzero-0.3.11/torchzero/modules/experimental/adam_lambertw.py +0 -149
  147. torchzero-0.3.11/torchzero/modules/experimental/adaptive_step_size.py +0 -90
  148. torchzero-0.3.11/torchzero/modules/experimental/adasoap.py +0 -177
  149. torchzero-0.3.11/torchzero/modules/experimental/cosine.py +0 -214
  150. torchzero-0.3.11/torchzero/modules/experimental/cubic_adam.py +0 -97
  151. torchzero-0.3.11/torchzero/modules/experimental/eigendescent.py +0 -120
  152. torchzero-0.3.11/torchzero/modules/experimental/etf.py +0 -195
  153. torchzero-0.3.11/torchzero/modules/experimental/exp_adam.py +0 -113
  154. torchzero-0.3.11/torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  155. torchzero-0.3.11/torchzero/modules/experimental/hnewton.py +0 -85
  156. torchzero-0.3.11/torchzero/modules/experimental/modular_lbfgs.py +0 -265
  157. torchzero-0.3.11/torchzero/modules/experimental/newton_solver.py +0 -88
  158. torchzero-0.3.11/torchzero/modules/experimental/parabolic_search.py +0 -220
  159. torchzero-0.3.11/torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  160. torchzero-0.3.11/torchzero/modules/experimental/tensor_adagrad.py +0 -42
  161. torchzero-0.3.11/torchzero/modules/higher_order/__init__.py +0 -1
  162. torchzero-0.3.11/torchzero/modules/line_search/adaptive.py +0 -99
  163. torchzero-0.3.11/torchzero/modules/line_search/line_search.py +0 -239
  164. torchzero-0.3.11/torchzero/modules/line_search/polynomial.py +0 -233
  165. torchzero-0.3.11/torchzero/modules/line_search/strong_wolfe.py +0 -276
  166. torchzero-0.3.11/torchzero/modules/misc/gradient_accumulation.py +0 -70
  167. torchzero-0.3.11/torchzero/modules/misc/split.py +0 -103
  168. torchzero-0.3.11/torchzero/modules/momentum/__init__.py +0 -14
  169. torchzero-0.3.11/torchzero/modules/momentum/matrix_momentum.py +0 -193
  170. torchzero-0.3.11/torchzero/modules/optimizers/adagrad.py +0 -165
  171. torchzero-0.3.11/torchzero/modules/quasi_newton/lbfgs.py +0 -286
  172. torchzero-0.3.11/torchzero/modules/quasi_newton/lsr1.py +0 -218
  173. torchzero-0.3.11/torchzero/modules/quasi_newton/trust_region.py +0 -397
  174. torchzero-0.3.11/torchzero/modules/second_order/__init__.py +0 -3
  175. torchzero-0.3.11/torchzero/modules/second_order/newton_cg.py +0 -374
  176. torchzero-0.3.11/torchzero/modules/smoothing/gaussian.py +0 -198
  177. torchzero-0.3.11/torchzero/modules/step_size/adaptive.py +0 -122
  178. torchzero-0.3.11/torchzero/utils/__init__.py +0 -23
  179. torchzero-0.3.11/torchzero/utils/linalg/solve.py +0 -408
  180. torchzero-0.3.11/torchzero.egg-info/PKG-INFO +0 -404
  181. {torchzero-0.3.11 → torchzero-0.3.14}/setup.cfg +0 -0
  182. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_identical.py +0 -0
  183. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_module.py +0 -0
  184. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_utils_optimizer.py +0 -0
  185. {torchzero-0.3.11 → torchzero-0.3.14}/tests/test_vars.py +0 -0
  186. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/lion.py +0 -0
  187. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/orthograd.py +0 -0
  188. {torchzero-0.3.11/torchzero/modules/optimizers → torchzero-0.3.14/torchzero/modules/adaptive}/rmsprop.py +0 -0
  189. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/clipping/__init__.py +0 -0
  190. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/clipping/growth_clipping.py +0 -0
  191. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/curveball.py +0 -0
  192. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  193. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/accumulate.py +0 -0
  194. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/ops/unary.py +0 -0
  195. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/projections/__init__.py +0 -0
  196. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/projections/cast.py +0 -0
  197. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/projections/galore.py +0 -0
  198. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/second_order/nystrom.py +0 -0
  199. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/smoothing/laplacian.py +0 -0
  200. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/step_size/lr.py +0 -0
  201. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/weight_decay/__init__.py +0 -0
  202. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/wrappers/__init__.py +0 -0
  203. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  204. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/__init__.py +0 -0
  205. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/utility/__init__.py +0 -0
  206. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/__init__.py +0 -0
  207. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/mads.py +0 -0
  208. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/optim/wrappers/nevergrad.py +0 -0
  209. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/benchmark.py +0 -0
  210. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/linalg/svd.py +0 -0
  211. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/numberlist.py +0 -0
  212. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/ops.py +0 -0
  213. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/optuna_tools.py +0 -0
  214. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero/utils/params.py +0 -0
  215. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero.egg-info/dependency_links.txt +0 -0
  216. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero.egg-info/requires.txt +0 -0
  217. {torchzero-0.3.11 → torchzero-0.3.14}/torchzero.egg-info/top_level.txt +0 -0
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchzero
3
+ Version: 0.3.14
4
+ Summary: Modular optimization library for PyTorch.
5
+ Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/inikishev/torchzero
7
+ Project-URL: Repository, https://github.com/inikishev/torchzero
8
+ Project-URL: Issues, https://github.com/inikishev/torchzero/isses
9
+ Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: torch
13
+ Requires-Dist: numpy
14
+ Requires-Dist: typing_extensions
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
13
13
  name = "torchzero"
14
14
  description = "Modular optimization library for PyTorch."
15
15
 
16
- version = "0.3.11"
16
+ version = "0.3.14"
17
17
  dependencies = [
18
18
  "torch",
19
19
  "numpy",
@@ -56,14 +56,17 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
56
56
  if use_closure:
57
57
  def closure(backward=True):
58
58
  loss = objective()
59
+ losses.append(loss.detach())
59
60
  if backward:
60
61
  opt.zero_grad()
61
62
  loss.backward()
62
63
  return loss
63
- loss = opt.step(closure)
64
- assert loss is not None
65
- assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
66
- losses.append(loss)
64
+ ret = opt.step(closure)
65
+ assert ret is not None # the return should be the loss
66
+ with torch.no_grad():
67
+ loss = objective() # in case f(x_0) is not evaluated
68
+ assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
69
+ losses.append(loss.detach())
67
70
 
68
71
  else:
69
72
  loss = objective()
@@ -71,7 +74,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
71
74
  loss.backward()
72
75
  opt.step()
73
76
  assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
74
- losses.append(loss)
77
+ losses.append(loss.detach())
75
78
 
76
79
  losses.append(objective())
77
80
  return torch.stack(losses).nan_to_num(0,10000,10000).min()
@@ -374,6 +377,21 @@ RandomizedFDM_central4 = Run(
374
377
  func='booth', steps=50, loss=10, merge_invariant=True,
375
378
  sphere_steps=100, sphere_loss=450,
376
379
  )
380
+ RandomizedFDM_forward4 = Run(
381
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
382
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
383
+ needs_closure=True,
384
+ func='booth', steps=50, loss=10, merge_invariant=True,
385
+ sphere_steps=100, sphere_loss=450,
386
+ )
387
+ RandomizedFDM_forward5 = Run(
388
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
389
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
390
+ needs_closure=True,
391
+ func='booth', steps=50, loss=10, merge_invariant=True,
392
+ sphere_steps=100, sphere_loss=450,
393
+ )
394
+
377
395
 
378
396
  RandomizedFDM_4samples = Run(
379
397
  func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
@@ -382,13 +400,6 @@ RandomizedFDM_4samples = Run(
382
400
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
383
401
  sphere_steps=100, sphere_loss=400,
384
402
  )
385
- RandomizedFDM_4samples_lerp = Run(
386
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
387
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
388
- needs_closure=True,
389
- func='booth', steps=50, loss=1e-5, merge_invariant=True,
390
- sphere_steps=100, sphere_loss=505,
391
- )
392
403
  RandomizedFDM_4samples_no_pre_generate = Run(
393
404
  func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
394
405
  sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
@@ -455,25 +466,11 @@ Backtracking = Run(
455
466
  func='booth', steps=50, loss=0, merge_invariant=True,
456
467
  sphere_steps=2, sphere_loss=0,
457
468
  )
458
- Backtracking_try_negative = Run(
459
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
460
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
461
- needs_closure=True,
462
- func='booth', steps=50, loss=1e-9, merge_invariant=True,
463
- sphere_steps=2, sphere_loss=1e-10,
464
- )
465
469
  AdaptiveBacktracking = Run(
466
470
  func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
467
471
  sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
468
472
  needs_closure=True,
469
- func='booth', steps=50, loss=1e-12, merge_invariant=True,
470
- sphere_steps=2, sphere_loss=1e-10,
471
- )
472
- AdaptiveBacktracking_try_negative = Run(
473
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
474
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
475
- needs_closure=True,
476
- func='booth', steps=50, loss=1e-8, merge_invariant=True,
473
+ func='booth', steps=50, loss=1e-11, merge_invariant=True,
477
474
  sphere_steps=2, sphere_loss=1e-10,
478
475
  )
479
476
  # ----------------------------- line_search/scipy ---------------------------- #
@@ -578,8 +575,8 @@ UpdateGradientSignConsistency = Run(
578
575
  sphere_steps=10, sphere_loss=2,
579
576
  )
580
577
  IntermoduleCautious = Run(
581
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
582
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
578
+ func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
579
+ sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
583
580
  needs_closure=False,
584
581
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
585
582
  sphere_steps=10, sphere_loss=0.1,
@@ -592,8 +589,8 @@ ScaleByGradCosineSimilarity = Run(
592
589
  sphere_steps=10, sphere_loss=0.1,
593
590
  )
594
591
  ScaleModulesByCosineSimilarity = Run(
595
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
596
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
592
+ func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
593
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
597
594
  needs_closure=False,
598
595
  func='booth', steps=50, loss=0.005, merge_invariant=True,
599
596
  sphere_steps=10, sphere_loss=0.1,
@@ -601,47 +598,69 @@ ScaleModulesByCosineSimilarity = Run(
601
598
 
602
599
  # ------------------------- momentum/matrix_momentum ------------------------- #
603
600
  MatrixMomentum_forward = Run(
604
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
605
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
601
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
602
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
606
603
  needs_closure=True,
607
604
  func='booth', steps=50, loss=0.05, merge_invariant=True,
608
- sphere_steps=10, sphere_loss=0,
605
+ sphere_steps=10, sphere_loss=0.01,
609
606
  )
610
607
  MatrixMomentum_forward = Run(
611
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
612
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
608
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
609
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
613
610
  needs_closure=True,
614
611
  func='booth', steps=50, loss=0.05, merge_invariant=True,
615
- sphere_steps=10, sphere_loss=0,
612
+ sphere_steps=10, sphere_loss=0.01,
616
613
  )
617
614
  MatrixMomentum_forward = Run(
618
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
619
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
615
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
616
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
620
617
  needs_closure=True,
621
618
  func='booth', steps=50, loss=0.05, merge_invariant=True,
622
- sphere_steps=10, sphere_loss=0,
619
+ sphere_steps=10, sphere_loss=0.01,
623
620
  )
624
621
 
625
622
  AdaptiveMatrixMomentum_forward = Run(
626
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
627
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
623
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
624
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
628
625
  needs_closure=True,
629
- func='booth', steps=50, loss=0.002, merge_invariant=True,
630
- sphere_steps=10, sphere_loss=0,
626
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
627
+ sphere_steps=10, sphere_loss=0.05,
631
628
  )
632
629
  AdaptiveMatrixMomentum_central = Run(
633
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
634
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
630
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
631
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
635
632
  needs_closure=True,
636
- func='booth', steps=50, loss=0.002, merge_invariant=True,
637
- sphere_steps=10, sphere_loss=0,
633
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
634
+ sphere_steps=10, sphere_loss=0.05,
638
635
  )
639
636
  AdaptiveMatrixMomentum_autograd = Run(
640
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
641
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
637
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
638
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
642
639
  needs_closure=True,
643
- func='booth', steps=50, loss=0.002, merge_invariant=True,
644
- sphere_steps=10, sphere_loss=0,
640
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
641
+ sphere_steps=10, sphere_loss=0.05,
642
+ )
643
+
644
+ StochasticAdaptiveMatrixMomentum_forward = Run(
645
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
646
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
647
+ needs_closure=True,
648
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
649
+ sphere_steps=10, sphere_loss=0.05,
650
+ )
651
+ StochasticAdaptiveMatrixMomentum_central = Run(
652
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
653
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
654
+ needs_closure=True,
655
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
656
+ sphere_steps=10, sphere_loss=0.05,
657
+ )
658
+ StochasticAdaptiveMatrixMomentum_autograd = Run(
659
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
660
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
661
+ needs_closure=True,
662
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
663
+ sphere_steps=10, sphere_loss=0.05,
645
664
  )
646
665
 
647
666
  # EMA, momentum are covered by test_identical
@@ -668,8 +687,8 @@ UpdateSign = Run(
668
687
  sphere_steps=10, sphere_loss=0,
669
688
  )
670
689
  GradAccumulation = Run(
671
- func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05), 10), ),
672
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5), 10), ),
690
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
691
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
673
692
  needs_closure=False,
674
693
  func='booth', steps=50, loss=25, merge_invariant=True,
675
694
  sphere_steps=20, sphere_loss=1e-11,
@@ -725,24 +744,24 @@ Shampoo = Run(
725
744
 
726
745
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
727
746
  BFGS = Run(
728
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
729
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
747
+ func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
748
+ sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
730
749
  needs_closure=True,
731
750
  func='rosen', steps=50, loss=1e-10, merge_invariant=True,
732
751
  sphere_steps=10, sphere_loss=1e-10,
733
752
  )
734
753
  SR1 = Run(
735
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
736
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
754
+ func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
755
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
737
756
  needs_closure=True,
738
757
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
739
758
  sphere_steps=10, sphere_loss=0,
740
759
  )
741
760
  SSVM = Run(
742
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
743
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
761
+ func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
762
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
744
763
  needs_closure=True,
745
- func='rosen', steps=50, loss=0.5, merge_invariant=True,
764
+ func='rosen', steps=50, loss=0.2, merge_invariant=True,
746
765
  sphere_steps=10, sphere_loss=0,
747
766
  )
748
767
 
@@ -757,8 +776,8 @@ LBFGS = Run(
757
776
 
758
777
  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
759
778
  LSR1 = Run(
760
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
761
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
779
+ func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
780
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
762
781
  needs_closure=True,
763
782
  func='rosen', steps=50, loss=0, merge_invariant=True,
764
783
  sphere_steps=10, sphere_loss=0,
@@ -775,8 +794,8 @@ LSR1 = Run(
775
794
 
776
795
  # ---------------------------- second_order/newton --------------------------- #
777
796
  Newton = Run(
778
- func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
779
- sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
797
+ func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
798
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
780
799
  needs_closure=True,
781
800
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
782
801
  sphere_steps=2, sphere_loss=1e-9,
@@ -784,8 +803,8 @@ Newton = Run(
784
803
 
785
804
  # --------------------------- second_order/newton_cg -------------------------- #
786
805
  NewtonCG = Run(
787
- func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
788
- sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
806
+ func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
807
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
789
808
  needs_closure=True,
790
809
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
791
810
  sphere_steps=2, sphere_loss=3e-4,
@@ -793,11 +812,11 @@ NewtonCG = Run(
793
812
 
794
813
  # ---------------------------- smoothing/gaussian ---------------------------- #
795
814
  GaussianHomotopy = Run(
796
- func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
797
- sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
815
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
816
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
798
817
  needs_closure=True,
799
- func='booth', steps=20, loss=0.1, merge_invariant=True,
800
- sphere_steps=10, sphere_loss=200,
818
+ func='booth', steps=20, loss=0.01, merge_invariant=True,
819
+ sphere_steps=10, sphere_loss=1,
801
820
  )
802
821
 
803
822
  # ---------------------------- smoothing/laplacian --------------------------- #
@@ -879,14 +898,14 @@ Adan = Run(
879
898
  )
880
899
 
881
900
  # ------------------------------------ CGs ----------------------------------- #
882
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
901
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.DYHS, tz.m.ProjectedGradientMethod):
883
902
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
884
903
  # but also test 10 to make sure it doesn't explode after converging
885
904
  Run(
886
905
  func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
887
906
  sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
888
907
  needs_closure=True,
889
- func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=False, # strong wolfe adds float imprecision
908
+ func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
890
909
  sphere_steps=sphere_steps_, sphere_loss=0,
891
910
  )
892
911
 
@@ -917,10 +936,10 @@ for QN in (
917
936
  tz.m.SSVM,
918
937
  ):
919
938
  Run(
920
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
921
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
939
+ func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
940
+ sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
922
941
  needs_closure=True,
923
- func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
942
+ func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
924
943
  sphere_steps=10, sphere_loss=1e-20,
925
944
  )
926
945
 
@@ -977,22 +977,23 @@ def test_rademacher_like(big_tl: TensorList):
977
977
 
978
978
  @pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
979
979
  def test_sample_like(simple_tl: TensorList, dist):
980
- eps_scalar = 2.0
981
- result_tl_scalar = simple_tl.sample_like(eps_scalar, distribution=dist)
980
+ eps_scalar = 1
981
+ result_tl_scalar = simple_tl.sample_like(distribution=dist)
982
982
  assert isinstance(result_tl_scalar, TensorList)
983
983
  assert result_tl_scalar.shape == simple_tl.shape
984
984
 
985
- eps_list = [0.5, 1.0, 1.5]
986
- result_tl_list = simple_tl.sample_like(eps_list, distribution=dist)
985
+ eps_list = [1.0,]
986
+ result_tl_list = simple_tl.sample_like(distribution=dist)
987
987
  assert isinstance(result_tl_list, TensorList)
988
988
  assert result_tl_list.shape == simple_tl.shape
989
989
 
990
990
  # Basic checks based on distribution
991
991
  if dist == 'uniform':
992
- assert all(torch.all((t >= -eps_scalar/2) & (t <= eps_scalar/2)) for t in result_tl_scalar)
993
- assert all(torch.all((t >= -e/2) & (t <= e/2)) for t, e in zip(result_tl_list, eps_list))
992
+ assert all(torch.all((t >= -eps_scalar) & (t <= eps_scalar)) for t in result_tl_scalar)
993
+ assert all(torch.all((t >= -e) & (t <= e)) for t, e in zip(result_tl_list, eps_list))
994
994
  elif dist == 'sphere':
995
- assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
995
+ # assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
996
+ pass
996
997
  # Cannot check list version easily
997
998
  elif dist == 'rademacher':
998
999
  assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
@@ -0,0 +1,4 @@
1
+ from . import core, optim, utils
2
+ from .core import Modular
3
+ from .utils import set_compilation
4
+ from . import modules as m
@@ -0,0 +1,2 @@
1
+ from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
2
+ from .transform import Target, TensorwiseTransform, Transform, apply_transform