torchzero 0.3.15__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250) hide show
  1. {torchzero-0.3.15 → torchzero-0.4.1}/PKG-INFO +1 -1
  2. {torchzero-0.3.15 → torchzero-0.4.1}/pyproject.toml +1 -1
  3. {torchzero-0.3.15 → torchzero-0.4.1}/tests/test_identical.py +22 -22
  4. torchzero-0.4.1/tests/test_module_autograd.py +586 -0
  5. torchzero-0.4.1/tests/test_objective.py +188 -0
  6. {torchzero-0.3.15 → torchzero-0.4.1}/tests/test_opts.py +225 -214
  7. {torchzero-0.3.15 → torchzero-0.4.1}/tests/test_tensorlist.py +0 -8
  8. {torchzero-0.3.15 → torchzero-0.4.1}/tests/test_utils_optimizer.py +0 -1
  9. torchzero-0.4.1/torchzero/__init__.py +4 -0
  10. torchzero-0.4.1/torchzero/core/__init__.py +8 -0
  11. torchzero-0.4.1/torchzero/core/chain.py +47 -0
  12. torchzero-0.4.1/torchzero/core/functional.py +103 -0
  13. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/core/modular.py +53 -57
  14. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/core/module.py +132 -52
  15. torchzero-0.4.1/torchzero/core/objective.py +948 -0
  16. torchzero-0.4.1/torchzero/core/reformulation.py +98 -0
  17. torchzero-0.4.1/torchzero/core/transform.py +336 -0
  18. torchzero-0.4.1/torchzero/linalg/__init__.py +11 -0
  19. torchzero-0.4.1/torchzero/linalg/eigh.py +253 -0
  20. torchzero-0.4.1/torchzero/linalg/linalg_utils.py +14 -0
  21. {torchzero-0.3.15/torchzero/utils → torchzero-0.4.1/torchzero}/linalg/linear_operator.py +99 -49
  22. torchzero-0.4.1/torchzero/linalg/matrix_power.py +28 -0
  23. torchzero-0.4.1/torchzero/linalg/orthogonalize.py +93 -0
  24. {torchzero-0.3.15/torchzero/utils → torchzero-0.4.1/torchzero}/linalg/qr.py +16 -2
  25. {torchzero-0.3.15/torchzero/utils → torchzero-0.4.1/torchzero}/linalg/solve.py +74 -88
  26. torchzero-0.4.1/torchzero/linalg/svd.py +47 -0
  27. torchzero-0.4.1/torchzero/linalg/torch_linalg.py +168 -0
  28. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/__init__.py +4 -3
  29. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/__init__.py +11 -3
  30. torchzero-0.4.1/torchzero/modules/adaptive/adagrad.py +306 -0
  31. torchzero-0.4.1/torchzero/modules/adaptive/adahessian.py +195 -0
  32. torchzero-0.4.1/torchzero/modules/adaptive/adam.py +84 -0
  33. torchzero-0.4.1/torchzero/modules/adaptive/adan.py +115 -0
  34. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  35. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/aegd.py +12 -12
  36. torchzero-0.4.1/torchzero/modules/adaptive/esgd.py +150 -0
  37. torchzero-0.4.1/torchzero/modules/adaptive/ggt.py +186 -0
  38. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/lion.py +7 -11
  39. torchzero-0.4.1/torchzero/modules/adaptive/lre_optimizers.py +299 -0
  40. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/mars.py +7 -7
  41. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/matrix_momentum.py +48 -52
  42. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/msam.py +71 -53
  43. torchzero-0.4.1/torchzero/modules/adaptive/muon.py +184 -0
  44. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/natural_gradient.py +63 -41
  45. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/orthograd.py +11 -15
  46. torchzero-0.4.1/torchzero/modules/adaptive/psgd/__init__.py +5 -0
  47. torchzero-0.4.1/torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  48. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  49. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  50. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  51. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  52. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  53. torchzero-0.4.1/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  54. torchzero-0.4.1/torchzero/modules/adaptive/rmsprop.py +111 -0
  55. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/rprop.py +48 -47
  56. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/adaptive/sam.py +55 -45
  57. torchzero-0.4.1/torchzero/modules/adaptive/shampoo.py +248 -0
  58. torchzero-0.4.1/torchzero/modules/adaptive/soap.py +329 -0
  59. torchzero-0.4.1/torchzero/modules/adaptive/sophia_h.py +161 -0
  60. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/clipping/clipping.py +22 -25
  61. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/clipping/ema_clipping.py +31 -25
  62. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/clipping/growth_clipping.py +14 -17
  63. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/conjugate_gradient/cg.py +27 -38
  64. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/__init__.py +7 -6
  65. torchzero-0.4.1/torchzero/modules/experimental/adanystrom.py +258 -0
  66. torchzero-0.4.1/torchzero/modules/experimental/common_directions_whiten.py +142 -0
  67. torchzero-0.4.1/torchzero/modules/experimental/coordinate_momentum.py +36 -0
  68. torchzero-0.4.1/torchzero/modules/experimental/cubic_adam.py +160 -0
  69. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/curveball.py +25 -41
  70. torchzero-0.4.1/torchzero/modules/experimental/eigen_sr1.py +182 -0
  71. torchzero-0.4.1/torchzero/modules/experimental/eigengrad.py +207 -0
  72. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/gradmin.py +2 -2
  73. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/higher_order_newton.py +14 -40
  74. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/l_infinity.py +1 -1
  75. torchzero-0.4.1/torchzero/modules/experimental/matrix_nag.py +122 -0
  76. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/newton_solver.py +23 -54
  77. torchzero-0.4.1/torchzero/modules/experimental/newtonnewton.py +102 -0
  78. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  79. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  80. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/spsa1.py +3 -3
  81. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/structural_projections.py +1 -4
  82. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/grad_approximation/fdm.py +2 -2
  83. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  84. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  85. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/grad_approximation/rfdm.py +24 -21
  86. torchzero-0.4.1/torchzero/modules/least_squares/gn.py +232 -0
  87. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/backtracking.py +4 -4
  88. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/line_search.py +33 -33
  89. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/strong_wolfe.py +4 -4
  90. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/debug.py +12 -12
  91. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/escape.py +10 -10
  92. torchzero-0.4.1/torchzero/modules/misc/gradient_accumulation.py +68 -0
  93. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/homotopy.py +16 -8
  94. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/misc.py +121 -123
  95. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/multistep.py +52 -53
  96. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/regularization.py +49 -44
  97. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/split.py +31 -29
  98. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/switch.py +37 -32
  99. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/momentum/averaging.py +14 -14
  100. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/momentum/cautious.py +37 -31
  101. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/momentum/momentum.py +12 -12
  102. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/__init__.py +4 -4
  103. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/accumulate.py +21 -21
  104. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/binary.py +67 -66
  105. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/higher_level.py +20 -20
  106. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/multi.py +44 -41
  107. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/reduce.py +26 -23
  108. torchzero-0.4.1/torchzero/modules/ops/unary.py +131 -0
  109. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/ops/utility.py +47 -46
  110. torchzero-0.3.15/torchzero/modules/functional.py → torchzero-0.4.1/torchzero/modules/opt_utils.py +1 -1
  111. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/projections/galore.py +1 -1
  112. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/projections/projection.py +46 -43
  113. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/__init__.py +1 -1
  114. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/damping.py +2 -2
  115. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  116. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/lbfgs.py +10 -10
  117. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/lsr1.py +10 -10
  118. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  119. torchzero-0.4.1/torchzero/modules/quasi_newton/sg2.py +156 -0
  120. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/restarts/restars.py +39 -37
  121. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/second_order/__init__.py +2 -2
  122. torchzero-0.4.1/torchzero/modules/second_order/ifn.py +58 -0
  123. torchzero-0.4.1/torchzero/modules/second_order/inm.py +109 -0
  124. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/second_order/multipoint.py +40 -80
  125. torchzero-0.4.1/torchzero/modules/second_order/newton.py +262 -0
  126. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/second_order/newton_cg.py +105 -157
  127. torchzero-0.4.1/torchzero/modules/second_order/nystrom.py +302 -0
  128. torchzero-0.4.1/torchzero/modules/second_order/rsn.py +234 -0
  129. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/smoothing/laplacian.py +13 -12
  130. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/smoothing/sampling.py +10 -10
  131. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/step_size/adaptive.py +24 -24
  132. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/step_size/lr.py +17 -17
  133. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/termination/termination.py +32 -30
  134. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/cubic_regularization.py +3 -3
  135. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  136. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/trust_cg.py +2 -2
  137. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/trust_region.py +27 -22
  138. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/variance_reduction/svrg.py +23 -21
  139. torchzero-0.4.1/torchzero/modules/weight_decay/__init__.py +2 -0
  140. torchzero-0.4.1/torchzero/modules/weight_decay/reinit.py +83 -0
  141. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/weight_decay/weight_decay.py +17 -18
  142. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/wrappers/optim_wrapper.py +14 -14
  143. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/zeroth_order/cd.py +10 -7
  144. torchzero-0.4.1/torchzero/optim/mbs.py +291 -0
  145. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/root.py +3 -3
  146. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/utility/split.py +2 -1
  147. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/directsearch.py +27 -63
  148. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/fcmaes.py +14 -35
  149. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/mads.py +11 -31
  150. torchzero-0.4.1/torchzero/optim/wrappers/moors.py +66 -0
  151. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/nevergrad.py +4 -13
  152. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/nlopt.py +31 -25
  153. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/optuna.py +8 -13
  154. torchzero-0.4.1/torchzero/optim/wrappers/pybobyqa.py +124 -0
  155. torchzero-0.4.1/torchzero/optim/wrappers/scipy/__init__.py +7 -0
  156. torchzero-0.4.1/torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  157. torchzero-0.4.1/torchzero/optim/wrappers/scipy/brute.py +48 -0
  158. torchzero-0.4.1/torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  159. torchzero-0.4.1/torchzero/optim/wrappers/scipy/direct.py +69 -0
  160. torchzero-0.4.1/torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  161. torchzero-0.4.1/torchzero/optim/wrappers/scipy/experimental.py +141 -0
  162. torchzero-0.4.1/torchzero/optim/wrappers/scipy/minimize.py +151 -0
  163. torchzero-0.4.1/torchzero/optim/wrappers/scipy/sgho.py +111 -0
  164. torchzero-0.4.1/torchzero/optim/wrappers/wrapper.py +121 -0
  165. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/__init__.py +7 -25
  166. torchzero-0.4.1/torchzero/utils/benchmarks/__init__.py +0 -0
  167. torchzero-0.4.1/torchzero/utils/benchmarks/logistic.py +122 -0
  168. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/compile.py +2 -2
  169. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/derivatives.py +97 -73
  170. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/optimizer.py +4 -77
  171. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/python_tools.py +31 -0
  172. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/tensorlist.py +11 -5
  173. torchzero-0.4.1/torchzero/utils/thoad_tools.py +68 -0
  174. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero.egg-info/PKG-INFO +1 -1
  175. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero.egg-info/SOURCES.txt +49 -15
  176. torchzero-0.3.15/tests/test_vars.py +0 -185
  177. torchzero-0.3.15/torchzero/__init__.py +0 -4
  178. torchzero-0.3.15/torchzero/core/__init__.py +0 -5
  179. torchzero-0.3.15/torchzero/core/chain.py +0 -50
  180. torchzero-0.3.15/torchzero/core/functional.py +0 -37
  181. torchzero-0.3.15/torchzero/core/reformulation.py +0 -67
  182. torchzero-0.3.15/torchzero/core/transform.py +0 -442
  183. torchzero-0.3.15/torchzero/core/var.py +0 -376
  184. torchzero-0.3.15/torchzero/modules/adaptive/adagrad.py +0 -356
  185. torchzero-0.3.15/torchzero/modules/adaptive/adahessian.py +0 -224
  186. torchzero-0.3.15/torchzero/modules/adaptive/adam.py +0 -107
  187. torchzero-0.3.15/torchzero/modules/adaptive/adan.py +0 -96
  188. torchzero-0.3.15/torchzero/modules/adaptive/esgd.py +0 -171
  189. torchzero-0.3.15/torchzero/modules/adaptive/lmadagrad.py +0 -186
  190. torchzero-0.3.15/torchzero/modules/adaptive/muon.py +0 -246
  191. torchzero-0.3.15/torchzero/modules/adaptive/rmsprop.py +0 -103
  192. torchzero-0.3.15/torchzero/modules/adaptive/shampoo.py +0 -229
  193. torchzero-0.3.15/torchzero/modules/adaptive/soap.py +0 -265
  194. torchzero-0.3.15/torchzero/modules/adaptive/sophia_h.py +0 -185
  195. torchzero-0.3.15/torchzero/modules/experimental/momentum.py +0 -160
  196. torchzero-0.3.15/torchzero/modules/experimental/newtonnewton.py +0 -105
  197. torchzero-0.3.15/torchzero/modules/least_squares/gn.py +0 -161
  198. torchzero-0.3.15/torchzero/modules/misc/gradient_accumulation.py +0 -136
  199. torchzero-0.3.15/torchzero/modules/ops/unary.py +0 -131
  200. torchzero-0.3.15/torchzero/modules/quasi_newton/sg2.py +0 -292
  201. torchzero-0.3.15/torchzero/modules/second_order/ifn.py +0 -89
  202. torchzero-0.3.15/torchzero/modules/second_order/inm.py +0 -105
  203. torchzero-0.3.15/torchzero/modules/second_order/newton.py +0 -293
  204. torchzero-0.3.15/torchzero/modules/second_order/nystrom.py +0 -271
  205. torchzero-0.3.15/torchzero/modules/second_order/rsn.py +0 -227
  206. torchzero-0.3.15/torchzero/modules/weight_decay/__init__.py +0 -1
  207. torchzero-0.3.15/torchzero/optim/wrappers/scipy.py +0 -572
  208. torchzero-0.3.15/torchzero/utils/linalg/__init__.py +0 -12
  209. torchzero-0.3.15/torchzero/utils/linalg/matrix_funcs.py +0 -87
  210. torchzero-0.3.15/torchzero/utils/linalg/orthogonalize.py +0 -12
  211. torchzero-0.3.15/torchzero/utils/linalg/svd.py +0 -20
  212. torchzero-0.3.15/torchzero/utils/ops.py +0 -10
  213. {torchzero-0.3.15 → torchzero-0.4.1}/setup.cfg +0 -0
  214. {torchzero-0.3.15 → torchzero-0.4.1}/tests/test_module.py +0 -0
  215. {torchzero-0.3.15/torchzero/utils → torchzero-0.4.1/torchzero}/linalg/benchmark.py +0 -0
  216. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/clipping/__init__.py +0 -0
  217. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
  218. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/dct.py +0 -0
  219. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/experimental/fft.py +0 -0
  220. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/grad_approximation/__init__.py +0 -0
  221. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/least_squares/__init__.py +0 -0
  222. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/__init__.py +0 -0
  223. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/_polyinterp.py +0 -0
  224. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/adaptive.py +0 -0
  225. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/interpolation.py +0 -0
  226. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/line_search/scipy.py +0 -0
  227. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/misc/__init__.py +0 -0
  228. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/momentum/__init__.py +0 -0
  229. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/projections/__init__.py +0 -0
  230. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/projections/cast.py +0 -0
  231. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/restarts/__init__.py +0 -0
  232. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/smoothing/__init__.py +0 -0
  233. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/step_size/__init__.py +0 -0
  234. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/termination/__init__.py +0 -0
  235. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/__init__.py +0 -0
  236. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/trust_region/dogleg.py +0 -0
  237. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/variance_reduction/__init__.py +0 -0
  238. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/wrappers/__init__.py +0 -0
  239. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/modules/zeroth_order/__init__.py +0 -0
  240. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/__init__.py +0 -0
  241. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/utility/__init__.py +0 -0
  242. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/optim/wrappers/__init__.py +0 -0
  243. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/metrics.py +0 -0
  244. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/numberlist.py +0 -0
  245. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/optuna_tools.py +0 -0
  246. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/params.py +0 -0
  247. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero/utils/torch_tools.py +0 -0
  248. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero.egg-info/dependency_links.txt +0 -0
  249. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero.egg-info/requires.txt +0 -0
  250. {torchzero-0.3.15 → torchzero-0.4.1}/torchzero.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.3.15
3
+ Version: 0.4.1
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/inikishev/torchzero
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
13
13
  name = "torchzero"
14
14
  description = "Modular optimization library for PyTorch."
15
15
 
16
- version = "0.3.15"
16
+ version = "0.4.1"
17
17
  dependencies = [
18
18
  "torch",
19
19
  "numpy",
@@ -97,30 +97,30 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
97
97
  @pytest.mark.parametrize('amsgrad', [True, False])
98
98
  def test_adam(amsgrad):
99
99
  torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
100
- tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
101
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
102
- tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
103
- tz_fn4 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
104
- tz_fn5 = lambda p: tz.Modular(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
105
- tz_fn_ops = lambda p: tz.Modular(
100
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad))
101
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
102
+ tz_fn3 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
103
+ tz_fn4 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
104
+ tz_fn5 = lambda p: tz.Optimizer(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
105
+ tz_fn_ops = lambda p: tz.Optimizer(
106
106
  p,
107
107
  tz.m.DivModules(
108
108
  tz.m.EMA(0.9, debiased=True),
109
109
  [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
110
110
  ))
111
- tz_fn_ops2 = lambda p: tz.Modular(
111
+ tz_fn_ops2 = lambda p: tz.Optimizer(
112
112
  p,
113
113
  tz.m.DivModules(
114
114
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
115
115
  [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
116
116
  ))
117
- tz_fn_ops3 = lambda p: tz.Modular(
117
+ tz_fn_ops3 = lambda p: tz.Optimizer(
118
118
  p,
119
119
  tz.m.DivModules(
120
120
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
121
121
  [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
122
122
  ))
123
- tz_fn_ops4 = lambda p: tz.Modular(
123
+ tz_fn_ops4 = lambda p: tz.Optimizer(
124
124
  p,
125
125
  tz.m.DivModules(
126
126
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
@@ -145,19 +145,19 @@ def test_adam(amsgrad):
145
145
  @pytest.mark.parametrize('amsgrad', [True, False])
146
146
  @pytest.mark.parametrize('lr', [0.1, 1])
147
147
  def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
148
- tz_fn = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
149
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
148
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
149
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
150
150
  _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
151
151
 
152
152
  @pytest.mark.parametrize('centered', [True, False])
153
153
  def test_rmsprop(centered):
154
154
  torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
155
- tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(centered=centered, init='zeros'))
156
- tz_fn2 = lambda p: tz.Modular(
155
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(centered=centered, init='zeros'))
156
+ tz_fn2 = lambda p: tz.Optimizer(
157
157
  p,
158
158
  tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
159
159
  )
160
- tz_fn3 = lambda p: tz.Modular(
160
+ tz_fn3 = lambda p: tz.Optimizer(
161
161
  p,
162
162
  tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
163
163
  )
@@ -173,7 +173,7 @@ def test_rmsprop(centered):
173
173
  @pytest.mark.parametrize('centered', [True, False])
174
174
  @pytest.mark.parametrize('lr', [0.1, 1])
175
175
  def test_rmsprop_hyperparams(beta, eps, centered, lr):
176
- tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
176
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
177
177
  torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
178
178
  _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)
179
179
 
@@ -185,7 +185,7 @@ def test_rmsprop_hyperparams(beta, eps, centered, lr):
185
185
  @pytest.mark.parametrize('ub', [50, 1.5])
186
186
  @pytest.mark.parametrize('lr', [0.1, 1])
187
187
  def test_rprop(nplus, nminus, lb, ub, lr):
188
- tz_fn = lambda p: tz.Modular(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
188
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
189
189
  torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
190
190
  _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
191
191
  _assert_identical_merge_closure(tz_fn, 'cpu', 30)
@@ -193,8 +193,8 @@ def test_rprop(nplus, nminus, lb, ub, lr):
193
193
 
194
194
  def test_adagrad():
195
195
  torch_fn = lambda p: torch.optim.Adagrad(p, 1)
196
- tz_fn = lambda p: tz.Modular(p, tz.m.Adagrad(), tz.m.LR(1))
197
- tz_fn2 = lambda p: tz.Modular(
196
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adagrad(), tz.m.LR(1))
197
+ tz_fn2 = lambda p: tz.Optimizer(
198
198
  p,
199
199
  tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
200
200
  )
@@ -212,15 +212,15 @@ def test_adagrad():
212
212
  @pytest.mark.parametrize('lr', [0.1, 1])
213
213
  def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
214
214
  torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
215
- tz_fn1 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
216
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
215
+ tz_fn1 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
216
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
217
217
  _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
218
218
 
219
219
 
220
220
  @pytest.mark.parametrize('tensorwise', [True, False])
221
221
  def test_graft(tensorwise):
222
- graft1 = lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
223
- graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.Graft([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
222
+ graft1 = lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
223
+ graft2 = lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
224
224
  _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
225
225
  for fn in [graft1, graft2]:
226
226
  if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)