torchzero 0.4.3__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. {torchzero-0.4.3 → torchzero-0.4.4}/PKG-INFO +1 -1
  2. {torchzero-0.4.3 → torchzero-0.4.4}/pyproject.toml +1 -1
  3. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_opts.py +2 -2
  4. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/__init__.py +0 -0
  5. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/methods.py +37 -32
  6. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/__init__.py +0 -0
  7. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/module.py +0 -0
  8. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/__init__.py +0 -0
  9. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/benchmark.py +0 -0
  10. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/eigh.py +2 -2
  11. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/linear_operator.py +0 -0
  12. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/matrix_power.py +0 -0
  13. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/orthogonalize.py +0 -0
  14. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/qr.py +0 -0
  15. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/solve.py +0 -0
  16. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/svd.py +0 -0
  17. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/torch_linalg.py +1 -1
  18. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/__init__.py +0 -0
  19. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/__init__.py +0 -0
  20. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adagrad.py +0 -0
  21. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adahessian.py +0 -0
  22. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adam.py +0 -0
  23. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adan.py +0 -0
  24. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
  25. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/aegd.py +0 -0
  26. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/esgd.py +0 -0
  27. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/ggt.py +1 -1
  28. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/lion.py +0 -0
  29. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/mars.py +0 -0
  30. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/matrix_momentum.py +0 -0
  31. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/msam.py +0 -0
  32. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/muon.py +0 -0
  33. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/natural_gradient.py +0 -0
  34. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/orthograd.py +0 -0
  35. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/rmsprop.py +0 -0
  36. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/rprop.py +0 -0
  37. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/sam.py +6 -1
  38. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/shampoo.py +1 -1
  39. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/soap.py +15 -2
  40. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/sophia_h.py +0 -0
  41. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/ggt_basis.py +0 -0
  42. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/soap_basis.py +0 -0
  43. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/__init__.py +0 -0
  44. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/clipping.py +0 -0
  45. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/ema_clipping.py +0 -0
  46. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/clipping/growth_clipping.py +0 -0
  47. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
  48. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/conjugate_gradient/cg.py +0 -0
  49. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/__init__.py +0 -0
  50. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/coordinate_momentum.py +0 -0
  51. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/cubic_adam.py +0 -0
  52. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/curveball.py +0 -0
  53. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/dct.py +0 -0
  54. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/fft.py +0 -0
  55. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/gradmin.py +0 -0
  56. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/higher_order_newton.py +0 -0
  57. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/l_infinity.py +0 -0
  58. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/newton_solver.py +0 -0
  59. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/newtonnewton.py +0 -0
  60. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  61. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
  62. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/structural_projections.py +0 -0
  63. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/__init__.py +0 -0
  64. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/fdm.py +0 -0
  65. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  66. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  67. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/rfdm.py +0 -0
  68. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/grad_approximation/spsa1.py +0 -0
  69. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/least_squares/__init__.py +0 -0
  70. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/least_squares/gn.py +0 -0
  71. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/__init__.py +0 -0
  72. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/_polyinterp.py +0 -0
  73. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/adaptive.py +0 -0
  74. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/backtracking.py +0 -0
  75. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/line_search.py +0 -0
  76. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/scipy.py +0 -0
  77. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/strong_wolfe.py +0 -0
  78. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/__init__.py +0 -0
  79. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/debug.py +0 -0
  80. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/escape.py +0 -0
  81. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/gradient_accumulation.py +0 -0
  82. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/misc.py +0 -0
  83. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/multistep.py +2 -3
  84. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/regularization.py +0 -0
  85. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/split.py +1 -1
  86. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/switch.py +0 -0
  87. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/__init__.py +0 -0
  88. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/averaging.py +0 -0
  89. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/cautious.py +0 -0
  90. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/momentum/momentum.py +0 -0
  91. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/__init__.py +0 -0
  92. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/accumulate.py +0 -0
  93. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/binary.py +0 -0
  94. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/higher_level.py +0 -0
  95. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/multi.py +0 -0
  96. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/reduce.py +0 -0
  97. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/unary.py +0 -0
  98. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/ops/utility.py +0 -0
  99. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/opt_utils.py +0 -0
  100. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/__init__.py +0 -0
  101. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/cast.py +0 -0
  102. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/galore.py +0 -0
  103. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/projections/projection.py +0 -0
  104. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/__init__.py +0 -0
  105. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
  106. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
  107. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/lsr1.py +0 -0
  108. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
  109. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/__init__.py +0 -0
  110. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/inm.py +0 -0
  111. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/multipoint.py +0 -0
  112. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/newton.py +8 -1
  113. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/newton_cg.py +0 -0
  114. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/nystrom.py +0 -0
  115. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/__init__.py +0 -0
  116. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/laplacian.py +0 -0
  117. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/smoothing/sampling.py +0 -0
  118. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/__init__.py +1 -1
  119. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/adaptive.py +42 -0
  120. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/step_size/lr.py +0 -0
  121. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/termination/termination.py +2 -1
  122. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/__init__.py +0 -0
  123. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
  124. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/dogleg.py +0 -0
  125. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/levenberg_marquardt.py +0 -0
  126. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_cg.py +0 -0
  127. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/trust_region/trust_region.py +0 -0
  128. torchzero-0.4.4/torchzero/modules/weight_decay/__init__.py +8 -0
  129. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/weight_decay/weight_decay.py +84 -7
  130. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/wrappers/__init__.py +0 -0
  131. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  132. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/__init__.py +0 -0
  133. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/root.py +0 -0
  134. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/utility/__init__.py +0 -0
  135. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/utility/split.py +0 -0
  136. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/directsearch.py +0 -0
  137. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/fcmaes.py +0 -0
  138. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/mads.py +0 -0
  139. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/optuna.py +0 -0
  140. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/__init__.py +0 -0
  141. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/compile.py +0 -0
  142. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/derivatives.py +0 -0
  143. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/numberlist.py +0 -0
  144. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/optimizer.py +0 -0
  145. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/optuna_tools.py +0 -0
  146. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/params.py +0 -0
  147. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/python_tools.py +0 -0
  148. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/torch_tools.py +0 -0
  149. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/PKG-INFO +1 -1
  150. torchzero-0.4.3/torchzero/modules/weight_decay/__init__.py +0 -2
  151. {torchzero-0.4.3 → torchzero-0.4.4}/setup.cfg +0 -0
  152. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_identical.py +0 -0
  153. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_module.py +0 -0
  154. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_module_autograd.py +0 -0
  155. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_objective.py +0 -0
  156. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_tensorlist.py +0 -0
  157. {torchzero-0.4.3 → torchzero-0.4.4}/tests/test_utils_optimizer.py +0 -0
  158. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/__init__.py +0 -0
  159. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/_minimize/minimize.py +0 -0
  160. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/chain.py +0 -0
  161. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/functional.py +0 -0
  162. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/modular.py +0 -0
  163. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/objective.py +0 -0
  164. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/reformulation.py +0 -0
  165. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/core/transform.py +0 -0
  166. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/linalg_utils.py +0 -0
  167. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/linalg/sketch.py +0 -0
  168. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/lre_optimizers.py +0 -0
  169. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/__init__.py +0 -0
  170. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/_psgd_utils.py +0 -0
  171. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd.py +0 -0
  172. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_dense_newton.py +0 -0
  173. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_newton.py +0 -0
  174. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +0 -0
  175. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_newton.py +0 -0
  176. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +0 -0
  177. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/basis/__init__.py +0 -0
  178. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/experimental/matrix_nag.py +0 -0
  179. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/line_search/interpolation.py +0 -0
  180. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/misc/homotopy.py +0 -0
  181. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/damping.py +0 -0
  182. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/quasi_newton/sg2.py +0 -0
  183. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/restarts/__init__.py +0 -0
  184. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/restarts/restars.py +0 -0
  185. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/ifn.py +0 -0
  186. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/second_order/rsn.py +0 -0
  187. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/termination/__init__.py +0 -0
  188. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/variance_reduction/__init__.py +0 -0
  189. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/variance_reduction/svrg.py +0 -0
  190. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/weight_decay/reinit.py +0 -0
  191. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/zeroth_order/__init__.py +0 -0
  192. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/modules/zeroth_order/cd.py +0 -0
  193. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/mbs.py +0 -0
  194. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/__init__.py +0 -0
  195. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/moors.py +0 -0
  196. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/nevergrad.py +0 -0
  197. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/nlopt.py +0 -0
  198. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/pybobyqa.py +0 -0
  199. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/__init__.py +0 -0
  200. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/basin_hopping.py +0 -0
  201. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/brute.py +0 -0
  202. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/differential_evolution.py +0 -0
  203. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/direct.py +0 -0
  204. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/dual_annealing.py +0 -0
  205. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/experimental.py +0 -0
  206. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/minimize.py +0 -0
  207. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/scipy/sgho.py +0 -0
  208. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/optim/wrappers/wrapper.py +0 -0
  209. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/benchmarks/__init__.py +0 -0
  210. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/benchmarks/logistic.py +0 -0
  211. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/metrics.py +0 -0
  212. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/tensorlist.py +0 -0
  213. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero/utils/thoad_tools.py +0 -0
  214. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/SOURCES.txt +0 -0
  215. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/dependency_links.txt +0 -0
  216. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/requires.txt +0 -0
  217. {torchzero-0.4.3 → torchzero-0.4.4}/torchzero.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.4.3
3
+ Version: 0.4.4
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/inikishev/torchzero
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
13
13
  name = "torchzero"
14
14
  description = "Modular optimization library for PyTorch."
15
15
 
16
- version = "0.4.3"
16
+ version = "0.4.4"
17
17
  dependencies = [
18
18
  "torch",
19
19
  "numpy",
@@ -727,8 +727,8 @@ Adam = Run(
727
727
  )
728
728
  # ------------------------------ optimizers/soap ----------------------------- #
729
729
  SOAP = Run(
730
- func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
731
- sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
730
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(merge_small=True), tz.m.LR(0.4)),
731
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1, merge_small=True), tz.m.LR(1)),
732
732
  needs_closure=False,
733
733
  # merge and unmerge lrs are very different so need to test convergence separately somewhere
734
734
  func='rosen', steps=50, loss=4, merge_invariant=False,
@@ -14,82 +14,87 @@ from ..utils import tofloat
14
14
 
15
15
 
16
16
  def _get_method_from_str(method: str) -> list[Module]:
17
- method = ''.join(c for c in method.lower().strip() if c.isalnum())
17
+ stripped = ''.join(c for c in method.lower().strip() if c.isalnum())
18
18
 
19
- if method == "bfgs":
19
+ if stripped == "bfgs":
20
20
  return [m.RestartOnStuck(m.BFGS()), m.Backtracking()]
21
21
 
22
- if method == "lbfgs":
22
+ if stripped == "lbfgs":
23
23
  return [m.LBFGS(100), m.Backtracking()]
24
24
 
25
- if method == "newton":
25
+ if stripped == "newton":
26
26
  return [m.Newton(), m.Backtracking()]
27
27
 
28
- if method == "sfn":
28
+ if stripped == "sfn":
29
29
  return [m.Newton(eigval_fn=lambda x: x.abs().clip(min=1e-10)), m.Backtracking()]
30
30
 
31
- if method == "inm":
31
+ if stripped == "inm":
32
32
  return [m.ImprovedNewton(), m.Backtracking()]
33
33
 
34
- if method == 'crn':
34
+ if stripped == 'crn':
35
35
  return [m.CubicRegularization(m.Newton())]
36
36
 
37
- if method == "commondirections":
37
+ if stripped == "commondirections":
38
38
  return [m.SubspaceNewton(sketch_type='common_directions'), m.Backtracking()]
39
39
 
40
- if method == "trust":
40
+ if stripped == "trust":
41
41
  return [m.LevenbergMarquardt(m.Newton())]
42
42
 
43
- if method == "trustexact":
44
- return [m.TrustCG(m.Newton())]
45
-
46
- if method == "dogleg":
43
+ if stripped == "dogleg":
47
44
  return [m.Dogleg(m.Newton())]
48
45
 
49
- if method == "trustbfgs":
50
- return [m.LevenbergMarquardt(m.BFGS())]
46
+ if stripped == "trustbfgs":
47
+ return [m.RestartOnStuck(m.LevenbergMarquardt(m.BFGS()))]
51
48
 
52
- if method == "trustsr1":
53
- return [m.LevenbergMarquardt(m.SR1())]
49
+ if stripped == "trustsr1":
50
+ return [m.RestartOnStuck(m.LevenbergMarquardt(m.SR1()))]
54
51
 
55
- if method == "newtoncg":
52
+ if stripped == "newtoncg":
56
53
  return [m.NewtonCG(), m.Backtracking()]
57
54
 
58
- if method == "tn":
55
+ if stripped == "tn":
59
56
  return [m.NewtonCG(maxiter=10), m.Backtracking()]
60
57
 
61
- if method == "trustncg":
58
+ if stripped == "trustncg":
62
59
  return [m.NewtonCGSteihaug()]
63
60
 
64
- if method == "gd":
61
+ if stripped == "gd":
65
62
  return [m.Backtracking()]
66
63
 
67
- if method == "cg":
64
+ if stripped == "cg":
68
65
  return [m.FletcherReeves(), m.StrongWolfe(c2=0.1, fallback=True)]
69
66
 
70
- if method == "bb":
67
+ if stripped in ("shor", "shorr"):
68
+ return [m.ShorR(), m.StrongWolfe(c2=0.1, fallback=True)]
69
+
70
+ if stripped == "pgm":
71
+ return [m.ProjectedGradientMethod(), m.StrongWolfe(c2=0.1, fallback=True)]
72
+
73
+ if stripped == "bb":
71
74
  return [m.RestartOnStuck(m.BarzilaiBorwein())]
72
75
 
73
- if method == "bbstab":
76
+ if stripped == "bbstab":
74
77
  return [m.BBStab()]
75
78
 
76
- if method == "adgd":
79
+ if stripped == "adgd":
77
80
  return [m.AdGD()]
78
81
 
79
- if method in ("gn", "gaussnewton"):
82
+ if stripped in ("bd", "bolddriver"):
83
+ return [m.BoldDriver()]
84
+
85
+ if stripped in ("gn", "gaussnewton"):
80
86
  return [m.GaussNewton(), m.Backtracking()]
81
87
 
82
- if method == "rprop":
88
+ if stripped == "rprop":
83
89
  return [m.Rprop(alpha=1e-3)]
84
90
 
85
- if method == "lm":
91
+ if stripped == "lm":
86
92
  return [m.LevenbergMarquardt(m.GaussNewton())]
87
93
 
88
- if method == "mlm":
94
+ if stripped == "mlm":
89
95
  return [m.LevenbergMarquardt(m.GaussNewton(), y=1)]
90
96
 
91
- if method == "cd":
97
+ if stripped == "cd":
92
98
  return [m.CD(), m.ScipyMinimizeScalar(maxiter=8)]
93
99
 
94
-
95
- raise NotImplementedError(method)
100
+ raise NotImplementedError(stripped)
@@ -285,8 +285,8 @@ def rank1_eigh(v: torch.Tensor):
285
285
  vv = v.dot(v)
286
286
  norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)
287
287
 
288
- L = vv.unsqueeze(0) # (rank, )
289
- Q = v.unsqueeze(-1) / norm # (m, rank)
288
+ L = vv.unsqueeze(0) # (1, )
289
+ Q = v.unsqueeze(-1) / norm # (m, 1)
290
290
 
291
291
  return L, Q
292
292
 
@@ -46,7 +46,7 @@ def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Ten
46
46
  try:
47
47
  return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable
48
48
 
49
- except torch.linalg.LinAlgError as e:
49
+ except (torch.linalg.LinAlgError, RuntimeError) as e:
50
50
  if not retry_float64: raise e
51
51
  dtype = A.dtype
52
52
  if dtype == torch.float64: raise e
@@ -130,7 +130,7 @@ class GGT(TensorTransform):
130
130
  step = state.get('step', 0)
131
131
  state['step'] = step + 1
132
132
 
133
- if step % update_freq == 0 :
133
+ if step % update_freq == 0:
134
134
 
135
135
  # compute new factors
136
136
  L = state.get("L", None)
@@ -1,7 +1,10 @@
1
+ from collections.abc import Mapping, Sequence
1
2
  from contextlib import nullcontext
3
+ from typing import Any
2
4
  import torch
5
+
3
6
  from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
4
- from ...core import Transform
7
+ from ...core import Transform, Objective
5
8
 
6
9
 
7
10
  class SAM(Transform):
@@ -126,6 +129,8 @@ class SAM(Transform):
126
129
 
127
130
  objective.closure = sam_closure
128
131
 
132
+ def apply_states(self, objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Objective:
133
+ return objective
129
134
  # different class because defaults for SAM are bad for ASAM
130
135
  class ASAM(SAM):
131
136
  """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
@@ -31,7 +31,7 @@ def update_shampoo_preconditioner_(
31
31
  if reg != 0:
32
32
  accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
33
33
 
34
- if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
34
+ if matrix_power is None: matrix_power = -1 / max(grad.ndim * 2, 2)
35
35
  set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))
36
36
 
37
37
  def apply_shampoo_preconditioner(
@@ -51,6 +51,7 @@ def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
51
51
  return tensor
52
52
 
53
53
  # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
54
+ # this is only used once per accumulator to initialize it
54
55
  @torch.no_grad
55
56
  def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
56
57
  """
@@ -64,7 +65,19 @@ def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
64
65
  final.append(None)
65
66
  continue
66
67
 
67
- _, Q = torch_linalg.eigh(M + 1e-30 * torch.eye(M.shape[0], device=M.device), retry_float64=True)
68
+ if not torch.isfinite(M).all():
69
+ raise RuntimeError(f"Initial gradient for parameter {M.shape} has non-finite values.")
70
+
71
+ M_f64 = M.to(torch.float64) + 1e-30 * torch.eye(M.shape[0], device=M.device, dtype=torch.float64)
72
+ try:
73
+ _, Q_f64 = torch_linalg.eigh(M_f64)
74
+ except RuntimeError as e:
75
+ if M_f64.is_cpu: raise e
76
+ M_f64 = M_f64.cpu()
77
+ _, Q_f64 = torch_linalg.eigh(M_f64) # apparently there is a bug in CUDA eigh
78
+ Q_f64 = Q_f64.to(M.device)
79
+
80
+ Q = Q_f64.to(M.dtype)
68
81
 
69
82
  Q = torch.flip(Q, [1])
70
83
  final.append(Q)
@@ -156,7 +169,7 @@ class SOAP(TensorTransform):
156
169
  beta2: float = 0.95,
157
170
  shampoo_beta: float | None = 0.95,
158
171
  precond_freq: int = 10,
159
- merge_small: bool = True,
172
+ merge_small: bool = False,
160
173
  max_dim: int = 4096,
161
174
  precondition_1d: bool = True,
162
175
  eps: float = 1e-8,
@@ -154,8 +154,7 @@ class Online(Module):
154
154
  closure = objective.closure
155
155
  if closure is None: raise ValueError("Closure must be passed for Online")
156
156
 
157
- step = self.global_state.get('step', 0) + 1
158
- self.global_state['step'] = step
157
+ step = self.increment_counter("step", start = 0)
159
158
 
160
159
  params = TensorList(objective.params)
161
160
  p_cur = params.clone()
@@ -165,7 +164,7 @@ class Online(Module):
165
164
  var_c = objective.clone(clone_updates=False)
166
165
 
167
166
  # on 1st step just step and store previous params
168
- if step == 1:
167
+ if step == 0:
169
168
  p_prev.copy_(params)
170
169
 
171
170
  module.update(var_c)
@@ -53,11 +53,11 @@ _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.T
53
53
  Filter = _SingleFilter | Iterable[_SingleFilter]
54
54
 
55
55
  def _make_filter(filter: Filter):
56
- if callable(filter): return filter
57
56
  if isinstance(filter, torch.Tensor):
58
57
  return lambda x: x is filter
59
58
  if isinstance(filter, torch.nn.Module):
60
59
  return _make_filter(filter.parameters())
60
+ if callable(filter): return filter
61
61
 
62
62
  # iterable
63
63
  filters = [_make_filter(f) for f in filter]
@@ -44,7 +44,14 @@ def _newton_update_state_(
44
44
 
45
45
  # if any args require eigendecomp, we don't need H or H_inv, we store factors
46
46
  if any(i is not None for i in [eigval_fn, eigv_tol, truncate]):
47
- L, Q = torch_linalg.eigh(H, retry_float64=True)
47
+ try:
48
+ state.pop("H", None)
49
+ L, Q = torch_linalg.eigh(H, retry_float64=True)
50
+ except torch.linalg.LinAlgError:
51
+ state.pop("L",None); state.pop("Q",None)
52
+ state["H"] = H
53
+ return
54
+
48
55
  if eigval_fn is not None: L = eigval_fn(L)
49
56
  L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eigv_tol)
50
57
  state["L"] = L
@@ -1,2 +1,2 @@
1
1
  from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
2
- from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
2
+ from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD, BoldDriver