torchzero 0.4.2__tar.gz → 0.4.4__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (217)
  1. {torchzero-0.4.2 → torchzero-0.4.4}/PKG-INFO +1 -1
  2. {torchzero-0.4.2 → torchzero-0.4.4}/pyproject.toml +1 -1
  3. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_identical.py +1 -1
  4. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_opts.py +2 -2
  5. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/__init__.py +0 -0
  6. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/methods.py +37 -32
  7. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/__init__.py +0 -0
  8. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/module.py +0 -0
  9. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/__init__.py +0 -0
  10. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/benchmark.py +0 -0
  11. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/eigh.py +2 -2
  12. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/linear_operator.py +0 -0
  13. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/matrix_power.py +0 -0
  14. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/orthogonalize.py +0 -0
  15. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/qr.py +0 -0
  16. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/solve.py +0 -0
  17. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/svd.py +0 -0
  18. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py +1 -1
  19. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/__init__.py +0 -0
  20. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/__init__.py +0 -0
  21. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adagrad.py +0 -0
  22. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adahessian.py +0 -0
  23. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adam.py +0 -0
  24. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adan.py +0 -0
  25. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
  26. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/aegd.py +0 -0
  27. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/esgd.py +0 -0
  28. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/ggt.py +1 -1
  29. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/lion.py +0 -0
  30. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/mars.py +0 -0
  31. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
  32. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/msam.py +0 -0
  33. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/muon.py +0 -0
  34. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/natural_gradient.py +0 -0
  35. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/orthograd.py +0 -0
  36. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/rmsprop.py +0 -0
  37. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/rprop.py +0 -0
  38. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py +6 -1
  39. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py +1 -1
  40. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py +15 -2
  41. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/sophia_h.py +0 -0
  42. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/ggt_basis.py +1 -1
  43. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/soap_basis.py +0 -0
  44. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/__init__.py +0 -0
  45. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/clipping.py +0 -0
  46. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/ema_clipping.py +0 -0
  47. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/clipping/growth_clipping.py +0 -0
  48. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
  49. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/cg.py +0 -0
  50. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/__init__.py +0 -0
  51. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/coordinate_momentum.py +0 -0
  52. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/cubic_adam.py +0 -0
  53. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/curveball.py +0 -0
  54. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/dct.py +0 -0
  55. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/fft.py +0 -0
  56. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/gradmin.py +0 -0
  57. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/higher_order_newton.py +0 -0
  58. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/l_infinity.py +0 -0
  59. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/newton_solver.py +0 -0
  60. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/newtonnewton.py +0 -0
  61. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  62. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
  63. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/structural_projections.py +0 -0
  64. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/__init__.py +0 -0
  65. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/fdm.py +0 -0
  66. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  67. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  68. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/rfdm.py +0 -0
  69. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/grad_approximation/spsa1.py +0 -0
  70. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/least_squares/__init__.py +0 -0
  71. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/least_squares/gn.py +0 -0
  72. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/__init__.py +0 -0
  73. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/_polyinterp.py +0 -0
  74. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/adaptive.py +0 -0
  75. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/backtracking.py +0 -0
  76. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/line_search.py +0 -0
  77. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/scipy.py +0 -0
  78. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/strong_wolfe.py +0 -0
  79. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/__init__.py +0 -0
  80. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/debug.py +0 -0
  81. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/escape.py +0 -0
  82. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/gradient_accumulation.py +0 -0
  83. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/misc.py +0 -0
  84. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py +2 -3
  85. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/regularization.py +0 -0
  86. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/split.py +1 -1
  87. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/switch.py +0 -0
  88. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/__init__.py +0 -0
  89. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/averaging.py +0 -0
  90. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/cautious.py +0 -0
  91. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/momentum/momentum.py +9 -9
  92. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/__init__.py +0 -0
  93. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/accumulate.py +0 -0
  94. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/binary.py +0 -0
  95. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/higher_level.py +0 -0
  96. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/multi.py +0 -0
  97. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/reduce.py +0 -0
  98. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/unary.py +0 -0
  99. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/ops/utility.py +0 -0
  100. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/opt_utils.py +0 -0
  101. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/__init__.py +0 -0
  102. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/cast.py +0 -0
  103. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/galore.py +0 -0
  104. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/projections/projection.py +0 -0
  105. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/__init__.py +0 -0
  106. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
  107. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
  108. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lsr1.py +0 -0
  109. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
  110. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/__init__.py +0 -0
  111. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/inm.py +0 -0
  112. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/multipoint.py +0 -0
  113. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/newton.py +8 -1
  114. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/newton_cg.py +0 -0
  115. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/nystrom.py +0 -0
  116. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/__init__.py +0 -0
  117. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/laplacian.py +0 -0
  118. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/smoothing/sampling.py +0 -0
  119. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/__init__.py +1 -1
  120. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/adaptive.py +42 -0
  121. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/step_size/lr.py +0 -0
  122. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/termination/termination.py +2 -1
  123. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/__init__.py +0 -0
  124. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
  125. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/dogleg.py +0 -0
  126. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/levenberg_marquardt.py +0 -0
  127. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_cg.py +0 -0
  128. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_region.py +0 -0
  129. torchzero-0.4.4/torchzero/modules/weight_decay/__init__.py +8 -0
  130. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/weight_decay/weight_decay.py +84 -7
  131. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/wrappers/__init__.py +0 -0
  132. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  133. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/__init__.py +0 -0
  134. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/root.py +0 -0
  135. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/utility/__init__.py +0 -0
  136. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/utility/split.py +0 -0
  137. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/directsearch.py +0 -0
  138. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/fcmaes.py +0 -0
  139. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/mads.py +0 -0
  140. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/optuna.py +0 -0
  141. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/__init__.py +0 -0
  142. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/compile.py +0 -0
  143. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/derivatives.py +0 -0
  144. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/numberlist.py +0 -0
  145. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/optimizer.py +0 -0
  146. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/optuna_tools.py +1 -1
  147. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/params.py +0 -0
  148. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/python_tools.py +0 -0
  149. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/torch_tools.py +0 -0
  150. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/PKG-INFO +1 -1
  151. torchzero-0.4.2/torchzero/modules/weight_decay/__init__.py +0 -2
  152. {torchzero-0.4.2 → torchzero-0.4.4}/setup.cfg +0 -0
  153. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_module.py +0 -0
  154. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_module_autograd.py +0 -0
  155. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_objective.py +0 -0
  156. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_tensorlist.py +0 -0
  157. {torchzero-0.4.2 → torchzero-0.4.4}/tests/test_utils_optimizer.py +0 -0
  158. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/__init__.py +0 -0
  159. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/_minimize/minimize.py +0 -0
  160. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/chain.py +0 -0
  161. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/functional.py +0 -0
  162. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/modular.py +0 -0
  163. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/objective.py +0 -0
  164. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/reformulation.py +0 -0
  165. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/core/transform.py +0 -0
  166. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/linalg_utils.py +0 -0
  167. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/linalg/sketch.py +0 -0
  168. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/lre_optimizers.py +0 -0
  169. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/__init__.py +0 -0
  170. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/_psgd_utils.py +0 -0
  171. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd.py +0 -0
  172. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +0 -0
  173. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +0 -0
  174. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +0 -0
  175. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +0 -0
  176. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +0 -0
  177. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/basis/__init__.py +0 -0
  178. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/experimental/matrix_nag.py +0 -0
  179. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/line_search/interpolation.py +0 -0
  180. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/misc/homotopy.py +0 -0
  181. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/damping.py +0 -0
  182. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/quasi_newton/sg2.py +0 -0
  183. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/restarts/__init__.py +0 -0
  184. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/restarts/restars.py +0 -0
  185. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/ifn.py +0 -0
  186. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/second_order/rsn.py +0 -0
  187. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/termination/__init__.py +0 -0
  188. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/variance_reduction/__init__.py +0 -0
  189. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/variance_reduction/svrg.py +0 -0
  190. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/weight_decay/reinit.py +0 -0
  191. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/zeroth_order/__init__.py +0 -0
  192. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/modules/zeroth_order/cd.py +0 -0
  193. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/mbs.py +0 -0
  194. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/__init__.py +0 -0
  195. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/moors.py +0 -0
  196. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/nevergrad.py +0 -0
  197. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/nlopt.py +0 -0
  198. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/pybobyqa.py +0 -0
  199. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
  200. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
  201. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/brute.py +0 -0
  202. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
  203. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/direct.py +0 -0
  204. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
  205. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
  206. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
  207. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
  208. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/optim/wrappers/wrapper.py +0 -0
  209. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/benchmarks/__init__.py +0 -0
  210. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/benchmarks/logistic.py +0 -0
  211. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/metrics.py +0 -0
  212. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/tensorlist.py +0 -0
  213. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero/utils/thoad_tools.py +0 -0
  214. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/SOURCES.txt +0 -0
  215. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/dependency_links.txt +0 -0
  216. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/requires.txt +0 -0
  217. {torchzero-0.4.2 → torchzero-0.4.4}/torchzero.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.4.2
+Version: 0.4.4
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 Project-URL: Homepage, https://github.com/inikishev/torchzero
pyproject.toml
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."

-version = "0.4.2"
+version = "0.4.4"
 dependencies = [
     "torch",
     "numpy",
tests/test_identical.py
@@ -105,7 +105,7 @@ def test_adam(amsgrad):
     tz_fn_ops = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
-            tz.m.EMA(0.9, debiased=True),
+            tz.m.EMA(0.9, debias=True),
             [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
         ))
     tz_fn_ops2 = lambda p: tz.Optimizer(
tests/test_opts.py
@@ -727,8 +727,8 @@ Adam = Run(
 )
 # ------------------------------ optimizers/soap ----------------------------- #
 SOAP = Run(
-    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
-    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
+    func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(merge_small=True), tz.m.LR(0.4)),
+    sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1, merge_small=True), tz.m.LR(1)),
     needs_closure=False,
     # merge and unmerge lrs are very different so need to test convergence separately somewhere
     func='rosen', steps=50, loss=4, merge_invariant=False,
torchzero/_minimize/methods.py
@@ -14,82 +14,87 @@ from ..utils import tofloat


 def _get_method_from_str(method: str) -> list[Module]:
-    method = ''.join(c for c in method.lower().strip() if c.isalnum())
+    stripped = ''.join(c for c in method.lower().strip() if c.isalnum())

-    if method == "bfgs":
+    if stripped == "bfgs":
         return [m.RestartOnStuck(m.BFGS()), m.Backtracking()]

-    if method == "lbfgs":
+    if stripped == "lbfgs":
         return [m.LBFGS(100), m.Backtracking()]

-    if method == "newton":
+    if stripped == "newton":
         return [m.Newton(), m.Backtracking()]

-    if method == "sfn":
+    if stripped == "sfn":
         return [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()]

-    if method == "inm":
+    if stripped == "inm":
         return [m.ImprovedNewton(), m.Backtracking()]

-    if method == 'crn':
+    if stripped == 'crn':
         return [m.CubicRegularization(m.Newton())]

-    if method == "commondirections":
+    if stripped == "commondirections":
         return [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()]

-    if method == "trust":
+    if stripped == "trust":
         return [m.LevenbergMarquardt(m.Newton())]

-    if method == "trustexact":
-        return [m.TrustCG(m.Newton())]
-
-    if method == "dogleg":
+    if stripped == "dogleg":
         return [m.Dogleg(m.Newton())]

-    if method == "trustbfgs":
-        return [m.LevenbergMarquardt(m.BFGS())]
+    if stripped == "trustbfgs":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.BFGS()))]

-    if method == "trustsr1":
-        return [m.LevenbergMarquardt(m.SR1())]
+    if stripped == "trustsr1":
+        return [m.RestartOnStuck(m.LevenbergMarquardt(m.SR1()))]

-    if method == "newtoncg":
+    if stripped == "newtoncg":
         return [m.NewtonCG(), m.Backtracking()]

-    if method == "tn":
+    if stripped == "tn":
         return [m.NewtonCG(maxiter=10), m.Backtracking()]

-    if method == "trustncg":
+    if stripped == "trustncg":
         return [m.NewtonCGSteihaug()]

-    if method == "gd":
+    if stripped == "gd":
         return [m.Backtracking()]

-    if method == "cg":
+    if stripped == "cg":
         return [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)]

-    if method == "bb":
+    if stripped in ("shor", "shorr"):
+        return [m.ShorR(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "pgm":
+        return [m.ProjectedGradientMethod(), m.StrongWolfe(c2=0.1, fallback=True)]
+
+    if stripped == "bb":
         return [m.RestartOnStuck(m.BarzilaiBorwein())]

-    if method == "bbstab":
+    if stripped == "bbstab":
         return [m.BBStab()]

-    if method == "adgd":
+    if stripped == "adgd":
         return [m.AdGD()]

-    if method in ("gn", "gaussnewton"):
+    if stripped in ("bd", "bolddriver"):
+        return [m.BoldDriver()]
+
+    if stripped in ("gn", "gaussnewton"):
         return [m.GaussNewton(), m.Backtracking()]

-    if method == "rprop":
+    if stripped == "rprop":
         return [m.Rprop(alpha=1e-3)]

-    if method == "lm":
+    if stripped == "lm":
         return [m.LevenbergMarquardt(m.GaussNewton())]

-    if method == "mlm":
+    if stripped == "mlm":
         return [m.LevenbergMarquardt(m.GaussNewton(), y=1)]

-    if method == "cd":
+    if stripped == "cd":
         return [m.CD(), m.ScipyMinimizeScalar(maxiter=8)]

-
-    raise NotImplementedError(method)
+    raise NotImplementedError(stripped)
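
The new "shor"/"shorr", "pgm", and "bd"/"bolddriver" entries go through the same normalization as the existing ones, so spelling variants resolve to one key. A minimal sketch of that normalization (the helper name is ours, for illustration):

def normalize(method: str) -> str:
    # same normalization as _get_method_from_str: lowercase,
    # then drop everything that is not alphanumeric
    return ''.join(c for c in method.lower().strip() if c.isalnum())

assert normalize("Bold-Driver") == "bolddriver"  # matches the ("bd", "bolddriver") branch
assert normalize(" shor_r ") == "shorr"          # matches the ("shor", "shorr") branch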
torchzero/linalg/eigh.py
@@ -285,8 +285,8 @@ def rank1_eigh(v: torch.Tensor):
     vv = v.dot(v)
     norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)

-    L = vv.unsqueeze(0) # (rank, )
-    Q = v.unsqueeze(-1) / norm # (m, rank)
+    L = vv.unsqueeze(0) # (1, )
+    Q = v.unsqueeze(-1) / norm # (m, 1)

     return L, Q

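For context on the corrected shape comments: rank1_eigh decomposes the rank-1 matrix v vᵀ, which has a single nonzero eigenvalue ‖v‖² with unit eigenvector v/‖v‖, hence shapes (1,) and (m, 1). A quick check using only standard torch:

import torch

v = torch.randn(5)
L = v.dot(v).unsqueeze(0)         # the single nonzero eigenvalue ||v||^2, shape (1,)
Q = (v / v.norm()).unsqueeze(-1)  # its unit eigenvector, shape (5, 1)
A = torch.outer(v, v)             # the rank-1 matrix v v^T
assert torch.allclose(A @ Q, Q * L, atol=1e-5)  # A q = lambda q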
torchzero/linalg/torch_linalg.py
@@ -46,7 +46,7 @@ def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Ten
     try:
         return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable

-    except torch.linalg.LinAlgError as e:
+    except (torch.linalg.LinAlgError, RuntimeError) as e:
         if not retry_float64: raise e
         dtype = A.dtype
         if dtype == torch.float64: raise e
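
The broadened except matters because CUDA solver failures inside torch.linalg.eigh can surface as plain RuntimeError rather than torch.linalg.LinAlgError. A standalone sketch of the retry-in-float64 pattern the function implements (illustrative name, not the library's API):

import torch

def eigh_retry_float64(A: torch.Tensor):
    # try in the original dtype; on failure retry in float64, which is
    # often enough for the solver to converge
    try:
        return torch.linalg.eigh(A)
    except (torch.linalg.LinAlgError, RuntimeError):
        if A.dtype == torch.float64:
            raise
        L, Q = torch.linalg.eigh(A.to(torch.float64))
        return L.to(A.dtype), Q.to(A.dtype)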
torchzero/modules/adaptive/ggt.py
@@ -130,7 +130,7 @@ class GGT(TensorTransform):
         step = state.get('step', 0)
         state['step'] = step + 1

-        if step % update_freq == 0 :
+        if step % update_freq == 0:

             # compute new factors
             L = state.get("L", None)
torchzero/modules/adaptive/sam.py
@@ -1,7 +1,10 @@
+from collections.abc import Mapping, Sequence
 from contextlib import nullcontext
+from typing import Any
 import torch
+
 from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
-from ...core import Transform
+from ...core import Transform, Objective


 class SAM(Transform):
@@ -126,6 +129,8 @@ class SAM(Transform):

         objective.closure = sam_closure

+    def apply_states(self, objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Objective:
+        return objective
 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
     """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
torchzero/modules/adaptive/shampoo.py
@@ -31,7 +31,7 @@ def update_shampoo_preconditioner_(
     if reg != 0:
         accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)

-    if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+    if matrix_power is None: matrix_power = -1 / max(grad.ndim * 2, 2)
     set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))

 def apply_shampoo_preconditioner(
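
This corrects the default Shampoo root: for an order-k gradient the Shampoo paper preconditions each Kronecker factor with exponent -1/(2k), so a matrix gradient (ndim=2) should use -1/4, not the -1/2 the old expression produced. A quick check of both expressions:

ndim = 2                     # matrix-shaped gradient (order k = 2)
old = -1 / max(ndim, 2)      # -0.5, too aggressive
new = -1 / max(ndim * 2, 2)  # -0.25, the -1/(2k) exponent from the paper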
torchzero/modules/adaptive/soap.py
@@ -51,6 +51,7 @@ def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
     return tensor

 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+# this is only used once per accumulator to initialize it
 @torch.no_grad
 def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
     """
@@ -64,7 +65,19 @@ def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
             final.append(None)
             continue

-        _, Q = torch_linalg.eigh(M + 1e-30 * torch.eye(M.shape[0], device=M.device), retry_float64=True)
+        if not torch.isfinite(M).all():
+            raise RuntimeError(f"Initial gradient for parameter {M.shape} has non-finite values.")
+
+        M_f64 = M.to(torch.float64) + 1e-30 * torch.eye(M.shape[0], device=M.device, dtype=torch.float64)
+        try:
+            _, Q_f64 = torch_linalg.eigh(M_f64)
+        except RuntimeError as e:
+            if M_f64.is_cpu: raise e
+            M_f64 = M_f64.cpu()
+            _, Q_f64 = torch_linalg.eigh(M_f64) # apparently there is a bug in CUDA eigh
+            Q_f64 = Q_f64.to(M.device)
+
+        Q = Q_f64.to(M.dtype)

         Q = torch.flip(Q, [1])
         final.append(Q)
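
The same device-fallback pattern, isolated, in case it is useful elsewhere; a minimal sketch assuming standard torch only (the helper name is ours):

import torch

def eigh_cpu_fallback(M: torch.Tensor):
    # run eigh on the tensor's device; if the CUDA solver raises,
    # retry on CPU and move the results back
    try:
        return torch.linalg.eigh(M)
    except RuntimeError:
        if M.is_cpu:
            raise
        L, Q = torch.linalg.eigh(M.cpu())
        return L.to(M.device), Q.to(M.device)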
@@ -156,7 +169,7 @@ class SOAP(TensorTransform):
         beta2: float = 0.95,
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
-        merge_small: bool = True,
+        merge_small: bool = False,
         max_dim: int = 4096,
         precondition_1d: bool = True,
         eps: float = 1e-8,
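
Since merge_small now defaults to False, callers that want the old behavior have to opt in explicitly, as the updated tests above do. A hedged usage sketch:

import torch
import torchzero as tz

params = torch.nn.Linear(8, 8).parameters()
# merge_small defaulted to True in 0.4.2; pass it explicitly to keep
# the previous (merging) behavior under 0.4.4
opt = tz.Optimizer(params, tz.m.SOAP(merge_small=True), tz.m.LR(0.4))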
torchzero/modules/basis/ggt_basis.py
@@ -111,7 +111,7 @@ class GGTBasis(TensorTransform):
         inner: Chainable | None = None,
     ):
         defaults = locals().copy()
-        del defaults['self'], defaults['inner']
+        del defaults['self'], defaults['inner'], defaults["basis_opt"]

         super().__init__(defaults, concat_params=True, inner=inner)
         self.set_child("basis_opt", basis_opt)
torchzero/modules/misc/multistep.py
@@ -154,8 +154,7 @@ class Online(Module):
         closure = objective.closure
         if closure is None: raise ValueError("Closure must be passed for Online")

-        step = self.global_state.get('step', 0) + 1
-        self.global_state['step'] = step
+        step = self.increment_counter("step", start = 0)

         params = TensorList(objective.params)
         p_cur = params.clone()
@@ -165,7 +164,7 @@ class Online(Module):
         var_c = objective.clone(clone_updates=False)

         # on 1st step just step and store previous params
-        if step == 1:
+        if step == 0:
             p_prev.copy_(params)

             module.update(var_c)
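
The step == 1 → step == 0 change follows from the counter semantics this relies on: we assume increment_counter("step", start=0) returns the pre-increment value, i.e. 0 on the first call. A minimal sketch of that assumed behavior:

class CounterSketch:
    # assumed semantics of Module.increment_counter: return the current
    # value (starting at `start`) and store the incremented one
    def __init__(self):
        self.global_state = {}

    def increment_counter(self, key, start=0):
        value = self.global_state.get(key, start)
        self.global_state[key] = value + 1
        return value

c = CounterSketch()
assert c.increment_counter("step") == 0  # first call, matches `if step == 0`
assert c.increment_counter("step") == 1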
torchzero/modules/misc/split.py
@@ -53,11 +53,11 @@ _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.T
 Filter = _SingleFilter | Iterable[_SingleFilter]

 def _make_filter(filter: Filter):
-    if callable(filter): return filter
     if isinstance(filter, torch.Tensor):
         return lambda x: x is filter
     if isinstance(filter, torch.nn.Module):
         return _make_filter(filter.parameters())
+    if callable(filter): return filter

     # iterable
     filters = [_make_filter(f) for f in filter]
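
The reorder matters because torch.nn.Module instances are themselves callable: with the callable(filter) check first, a module passed as a filter was treated as a predicate instead of being expanded into its parameters. A quick demonstration:

import torch

net = torch.nn.Linear(2, 2)
# a Module passes the callable() check, so the isinstance(filter,
# torch.nn.Module) branch must run first for the module to be expanded
# into filters over its parameters
assert callable(net)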
torchzero/modules/momentum/momentum.py
@@ -6,7 +6,7 @@ import torch

 from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..opt_utils import debias, ema_
+from ..opt_utils import debias as _debias, ema_


 class EMA(TensorTransform):
@@ -15,13 +15,13 @@ class EMA(TensorTransform):
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
-        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
-        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
+        defaults = dict(momentum=momentum,dampening=dampening,debias=debias,lerp=lerp,ema_init=ema_init)
         super().__init__(defaults, uses_grad=False)

         self.add_projected_keys("grad", "exp_avg")
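
Callers of EMA (and of HeavyBall below) need the renamed keyword; note that SqrtEMASquared in the test above still takes debiased, so only these two modules are affected here. A usage sketch:

import torch
import torchzero as tz

params = torch.nn.Linear(4, 4).parameters()
# `debiased=True` raises a TypeError as of 0.4.4; the keyword is now `debias`
opt = tz.Optimizer(params, tz.m.EMA(0.9, debias=True), tz.m.LR(1e-2))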
@@ -30,7 +30,7 @@ class EMA(TensorTransform):
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+        debias, lerp, ema_init = itemgetter('debias','lerp','ema_init')(settings[0])

         exp_avg = unpack_states(states, tensors, 'exp_avg',
             init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
@@ -38,7 +38,7 @@ class EMA(TensorTransform):

         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

-        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+        if debias: return _debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
         else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned

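For reference, the debiasing here is Adam-style bias correction: the zero-initialized EMA is divided by 1 - beta**step. A tiny numeric check, assuming the standard formula is what _debias computes:

import torch

beta, step = 0.9, 3
# EMA of a constant signal of 1.0, initialized at zero, after 3 lerp steps
ema = torch.tensor([1 - beta**step])  # 0.271
corrected = ema / (1 - beta**step)    # exactly 1.0 after bias correction
assert torch.allclose(corrected, torch.tensor([1.0]))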
@@ -49,14 +49,14 @@ class HeavyBall(EMA):
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
-        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional):
             whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
-    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
-        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init)
+    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
+        super().__init__(momentum=momentum, dampening=dampening, debias=debias, lerp=lerp, ema_init=ema_init)

 def nag_(
     tensors_: TensorList,