torchzero 0.4.0__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {torchzero-0.4.0 → torchzero-0.4.2}/PKG-INFO +1 -1
  2. {torchzero-0.4.0 → torchzero-0.4.2}/pyproject.toml +1 -1
  3. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_identical.py +22 -22
  4. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_opts.py +199 -198
  5. torchzero-0.4.2/torchzero/__init__.py +6 -0
  6. {torchzero-0.4.0/torchzero/optim/wrappers → torchzero-0.4.2/torchzero/_minimize}/__init__.py +0 -0
  7. torchzero-0.4.2/torchzero/_minimize/methods.py +95 -0
  8. torchzero-0.4.2/torchzero/_minimize/minimize.py +518 -0
  9. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/__init__.py +6 -6
  10. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/chain.py +2 -1
  11. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/functional.py +2 -1
  12. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/modular.py +5 -5
  13. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/module.py +77 -6
  14. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/objective.py +10 -10
  15. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/transform.py +7 -6
  16. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/__init__.py +3 -2
  17. torchzero-0.4.2/torchzero/linalg/eigh.py +301 -0
  18. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/linear_operator.py +1 -0
  19. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/orthogonalize.py +62 -9
  20. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/qr.py +12 -0
  21. torchzero-0.4.2/torchzero/linalg/sketch.py +39 -0
  22. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/solve.py +1 -3
  23. torchzero-0.4.2/torchzero/linalg/svd.py +47 -0
  24. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/__init__.py +5 -3
  25. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/__init__.py +11 -3
  26. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/adagrad.py +12 -10
  27. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/adahessian.py +2 -2
  28. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/adam.py +6 -2
  29. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/adan.py +4 -1
  30. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  31. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/esgd.py +2 -2
  32. torchzero-0.4.2/torchzero/modules/adaptive/ggt.py +188 -0
  33. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/lion.py +5 -2
  34. torchzero-0.4.2/torchzero/modules/adaptive/lre_optimizers.py +299 -0
  35. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/mars.py +8 -7
  36. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/matrix_momentum.py +1 -1
  37. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/msam.py +7 -4
  38. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/muon.py +9 -6
  39. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/natural_gradient.py +32 -15
  40. torchzero-0.4.2/torchzero/modules/adaptive/psgd/__init__.py +5 -0
  41. torchzero-0.4.2/torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  42. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  43. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  44. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  45. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  46. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  47. torchzero-0.4.2/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  48. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/rmsprop.py +2 -0
  49. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/rprop.py +11 -9
  50. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/sam.py +4 -4
  51. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/shampoo.py +37 -4
  52. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/soap.py +35 -32
  53. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/sophia_h.py +2 -2
  54. torchzero-0.4.2/torchzero/modules/basis/__init__.py +2 -0
  55. torchzero-0.4.2/torchzero/modules/basis/ggt_basis.py +199 -0
  56. torchzero-0.4.2/torchzero/modules/basis/soap_basis.py +254 -0
  57. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/clipping/clipping.py +7 -7
  58. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/clipping/ema_clipping.py +32 -27
  59. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/clipping/growth_clipping.py +1 -0
  60. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/conjugate_gradient/cg.py +2 -2
  61. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/__init__.py +2 -2
  62. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/coordinate_momentum.py +2 -0
  63. torchzero-0.4.2/torchzero/modules/experimental/cubic_adam.py +164 -0
  64. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/l_infinity.py +1 -1
  65. torchzero-0.4.2/torchzero/modules/experimental/matrix_nag.py +122 -0
  66. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/newton_solver.py +2 -2
  67. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/newtonnewton.py +34 -40
  68. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/grad_approximation/__init__.py +3 -2
  69. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/grad_approximation/fdm.py +2 -2
  70. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/grad_approximation/rfdm.py +4 -4
  71. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/least_squares/gn.py +74 -45
  72. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/backtracking.py +2 -2
  73. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/line_search.py +1 -1
  74. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/strong_wolfe.py +2 -2
  75. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/escape.py +1 -1
  76. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/gradient_accumulation.py +2 -1
  77. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/misc.py +7 -1
  78. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/multistep.py +4 -7
  79. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/regularization.py +2 -2
  80. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/split.py +1 -1
  81. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/switch.py +2 -2
  82. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/momentum/averaging.py +6 -0
  83. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/momentum/cautious.py +3 -3
  84. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/momentum/momentum.py +5 -1
  85. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/__init__.py +0 -1
  86. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/accumulate.py +4 -0
  87. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/higher_level.py +7 -2
  88. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/multi.py +1 -1
  89. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/projections/projection.py +5 -2
  90. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/__init__.py +1 -1
  91. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/damping.py +1 -1
  92. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  93. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/lbfgs.py +3 -3
  94. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/lsr1.py +3 -3
  95. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  96. torchzero-0.4.2/torchzero/modules/quasi_newton/sg2.py +156 -0
  97. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/restarts/restars.py +17 -17
  98. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/inm.py +37 -25
  99. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/newton.py +142 -132
  100. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/newton_cg.py +10 -6
  101. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/nystrom.py +80 -34
  102. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/rsn.py +74 -46
  103. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/smoothing/laplacian.py +1 -1
  104. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/smoothing/sampling.py +2 -3
  105. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/step_size/adaptive.py +6 -6
  106. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/step_size/lr.py +2 -2
  107. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/cubic_regularization.py +1 -1
  108. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  109. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/trust_cg.py +1 -1
  110. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/trust_region.py +2 -1
  111. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/variance_reduction/svrg.py +4 -5
  112. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/weight_decay/reinit.py +2 -2
  113. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/weight_decay/weight_decay.py +5 -5
  114. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/wrappers/optim_wrapper.py +4 -4
  115. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/zeroth_order/cd.py +1 -1
  116. torchzero-0.4.2/torchzero/optim/mbs.py +291 -0
  117. torchzero-0.4.2/torchzero/optim/wrappers/__init__.py +0 -0
  118. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/nevergrad.py +0 -9
  119. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/optuna.py +2 -0
  120. torchzero-0.4.2/torchzero/utils/benchmarks/__init__.py +0 -0
  121. torchzero-0.4.2/torchzero/utils/benchmarks/logistic.py +137 -0
  122. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/derivatives.py +4 -4
  123. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/params.py +13 -1
  124. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/tensorlist.py +2 -2
  125. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero.egg-info/PKG-INFO +1 -1
  126. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero.egg-info/SOURCES.txt +25 -4
  127. torchzero-0.4.0/torchzero/__init__.py +0 -4
  128. torchzero-0.4.0/torchzero/linalg/eigh.py +0 -34
  129. torchzero-0.4.0/torchzero/linalg/svd.py +0 -20
  130. torchzero-0.4.0/torchzero/modules/adaptive/lmadagrad.py +0 -241
  131. torchzero-0.4.0/torchzero/modules/quasi_newton/sg2.py +0 -292
  132. {torchzero-0.4.0 → torchzero-0.4.2}/setup.cfg +0 -0
  133. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_module.py +0 -0
  134. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_module_autograd.py +0 -0
  135. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_objective.py +0 -0
  136. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_tensorlist.py +0 -0
  137. {torchzero-0.4.0 → torchzero-0.4.2}/tests/test_utils_optimizer.py +0 -0
  138. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/core/reformulation.py +0 -0
  139. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/benchmark.py +0 -0
  140. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/linalg_utils.py +0 -0
  141. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/matrix_power.py +0 -0
  142. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/linalg/torch_linalg.py +0 -0
  143. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/aegd.py +0 -0
  144. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/adaptive/orthograd.py +0 -0
  145. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/clipping/__init__.py +0 -0
  146. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
  147. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/curveball.py +0 -0
  148. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/dct.py +0 -0
  149. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/fft.py +0 -0
  150. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/gradmin.py +0 -0
  151. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/higher_order_newton.py +0 -0
  152. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  153. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
  154. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/experimental/structural_projections.py +0 -0
  155. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  156. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  157. {torchzero-0.4.0/torchzero/modules/experimental → torchzero-0.4.2/torchzero/modules/grad_approximation}/spsa1.py +0 -0
  158. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/least_squares/__init__.py +0 -0
  159. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/__init__.py +0 -0
  160. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/_polyinterp.py +0 -0
  161. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/adaptive.py +0 -0
  162. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/interpolation.py +0 -0
  163. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/line_search/scipy.py +0 -0
  164. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/__init__.py +0 -0
  165. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/debug.py +0 -0
  166. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/misc/homotopy.py +0 -0
  167. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/momentum/__init__.py +0 -0
  168. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/binary.py +0 -0
  169. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/reduce.py +0 -0
  170. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/unary.py +0 -0
  171. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/ops/utility.py +0 -0
  172. /torchzero-0.4.0/torchzero/modules/functional.py → /torchzero-0.4.2/torchzero/modules/opt_utils.py +0 -0
  173. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/projections/__init__.py +0 -0
  174. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/projections/cast.py +0 -0
  175. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/projections/galore.py +0 -0
  176. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/restarts/__init__.py +0 -0
  177. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/__init__.py +0 -0
  178. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/ifn.py +0 -0
  179. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/second_order/multipoint.py +0 -0
  180. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/smoothing/__init__.py +0 -0
  181. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/step_size/__init__.py +0 -0
  182. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/termination/__init__.py +0 -0
  183. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/termination/termination.py +0 -0
  184. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/__init__.py +0 -0
  185. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/trust_region/dogleg.py +0 -0
  186. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/variance_reduction/__init__.py +0 -0
  187. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/weight_decay/__init__.py +0 -0
  188. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/wrappers/__init__.py +0 -0
  189. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/modules/zeroth_order/__init__.py +0 -0
  190. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/__init__.py +0 -0
  191. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/root.py +0 -0
  192. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/utility/__init__.py +0 -0
  193. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/utility/split.py +0 -0
  194. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/directsearch.py +0 -0
  195. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/fcmaes.py +0 -0
  196. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/mads.py +0 -0
  197. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/moors.py +0 -0
  198. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/nlopt.py +0 -0
  199. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/pybobyqa.py +0 -0
  200. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
  201. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
  202. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/brute.py +0 -0
  203. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
  204. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/direct.py +0 -0
  205. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
  206. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
  207. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
  208. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
  209. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/optim/wrappers/wrapper.py +0 -0
  210. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/__init__.py +0 -0
  211. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/compile.py +0 -0
  212. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/metrics.py +0 -0
  213. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/numberlist.py +0 -0
  214. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/optimizer.py +0 -0
  215. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/optuna_tools.py +0 -0
  216. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/python_tools.py +0 -0
  217. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/thoad_tools.py +0 -0
  218. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero/utils/torch_tools.py +0 -0
  219. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero.egg-info/dependency_links.txt +0 -0
  220. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero.egg-info/requires.txt +0 -0
  221. {torchzero-0.4.0 → torchzero-0.4.2}/torchzero.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/inikishev/torchzero
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
13
13
  name = "torchzero"
14
14
  description = "Modular optimization library for PyTorch."
15
15
 
16
- version = "0.4.0"
16
+ version = "0.4.2"
17
17
  dependencies = [
18
18
  "torch",
19
19
  "numpy",
@@ -97,30 +97,30 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
97
97
  @pytest.mark.parametrize('amsgrad', [True, False])
98
98
  def test_adam(amsgrad):
99
99
  torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
100
- tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
101
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
102
- tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
103
- tz_fn4 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
104
- tz_fn5 = lambda p: tz.Modular(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
105
- tz_fn_ops = lambda p: tz.Modular(
100
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad))
101
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
102
+ tz_fn3 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
103
+ tz_fn4 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
104
+ tz_fn5 = lambda p: tz.Optimizer(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
105
+ tz_fn_ops = lambda p: tz.Optimizer(
106
106
  p,
107
107
  tz.m.DivModules(
108
108
  tz.m.EMA(0.9, debiased=True),
109
109
  [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
110
110
  ))
111
- tz_fn_ops2 = lambda p: tz.Modular(
111
+ tz_fn_ops2 = lambda p: tz.Optimizer(
112
112
  p,
113
113
  tz.m.DivModules(
114
114
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
115
115
  [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
116
116
  ))
117
- tz_fn_ops3 = lambda p: tz.Modular(
117
+ tz_fn_ops3 = lambda p: tz.Optimizer(
118
118
  p,
119
119
  tz.m.DivModules(
120
120
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
121
121
  [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
122
122
  ))
123
- tz_fn_ops4 = lambda p: tz.Modular(
123
+ tz_fn_ops4 = lambda p: tz.Optimizer(
124
124
  p,
125
125
  tz.m.DivModules(
126
126
  [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
@@ -145,19 +145,19 @@ def test_adam(amsgrad):
145
145
  @pytest.mark.parametrize('amsgrad', [True, False])
146
146
  @pytest.mark.parametrize('lr', [0.1, 1])
147
147
  def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
148
- tz_fn = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
149
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
148
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
149
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
150
150
  _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
151
151
 
152
152
  @pytest.mark.parametrize('centered', [True, False])
153
153
  def test_rmsprop(centered):
154
154
  torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
155
- tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(centered=centered, init='zeros'))
156
- tz_fn2 = lambda p: tz.Modular(
155
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(centered=centered, init='zeros'))
156
+ tz_fn2 = lambda p: tz.Optimizer(
157
157
  p,
158
158
  tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
159
159
  )
160
- tz_fn3 = lambda p: tz.Modular(
160
+ tz_fn3 = lambda p: tz.Optimizer(
161
161
  p,
162
162
  tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
163
163
  )
@@ -173,7 +173,7 @@ def test_rmsprop(centered):
173
173
  @pytest.mark.parametrize('centered', [True, False])
174
174
  @pytest.mark.parametrize('lr', [0.1, 1])
175
175
  def test_rmsprop_hyperparams(beta, eps, centered, lr):
176
- tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
176
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
177
177
  torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
178
178
  _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)
179
179
 
@@ -185,7 +185,7 @@ def test_rmsprop_hyperparams(beta, eps, centered, lr):
185
185
  @pytest.mark.parametrize('ub', [50, 1.5])
186
186
  @pytest.mark.parametrize('lr', [0.1, 1])
187
187
  def test_rprop(nplus, nminus, lb, ub, lr):
188
- tz_fn = lambda p: tz.Modular(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
188
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
189
189
  torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
190
190
  _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
191
191
  _assert_identical_merge_closure(tz_fn, 'cpu', 30)
@@ -193,8 +193,8 @@ def test_rprop(nplus, nminus, lb, ub, lr):
193
193
 
194
194
  def test_adagrad():
195
195
  torch_fn = lambda p: torch.optim.Adagrad(p, 1)
196
- tz_fn = lambda p: tz.Modular(p, tz.m.Adagrad(), tz.m.LR(1))
197
- tz_fn2 = lambda p: tz.Modular(
196
+ tz_fn = lambda p: tz.Optimizer(p, tz.m.Adagrad(), tz.m.LR(1))
197
+ tz_fn2 = lambda p: tz.Optimizer(
198
198
  p,
199
199
  tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
200
200
  )
@@ -212,15 +212,15 @@ def test_adagrad():
212
212
  @pytest.mark.parametrize('lr', [0.1, 1])
213
213
  def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
214
214
  torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
215
- tz_fn1 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
216
- tz_fn2 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
215
+ tz_fn1 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
216
+ tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
217
217
  _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
218
218
 
219
219
 
220
220
  @pytest.mark.parametrize('tensorwise', [True, False])
221
221
  def test_graft(tensorwise):
222
- graft1 = lambda p: tz.Modular(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
223
- graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
222
+ graft1 = lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
223
+ graft2 = lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
224
224
  _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
225
225
  for fn in [graft1, graft2]:
226
226
  if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)