torchzero 0.3.13__tar.gz → 0.3.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. {torchzero-0.3.13 → torchzero-0.3.15}/PKG-INFO +1 -1
  2. {torchzero-0.3.13 → torchzero-0.3.15}/pyproject.toml +1 -1
  3. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_opts.py +4 -10
  4. torchzero-0.3.15/torchzero/core/__init__.py +5 -0
  5. torchzero-0.3.15/torchzero/core/chain.py +50 -0
  6. torchzero-0.3.15/torchzero/core/functional.py +37 -0
  7. torchzero-0.3.15/torchzero/core/modular.py +237 -0
  8. torchzero-0.3.15/torchzero/core/module.py +327 -0
  9. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/core/reformulation.py +3 -1
  10. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/core/transform.py +7 -5
  11. torchzero-0.3.15/torchzero/core/var.py +376 -0
  12. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/__init__.py +0 -1
  13. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/adahessian.py +2 -2
  14. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/esgd.py +2 -2
  15. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/matrix_momentum.py +1 -1
  16. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/sophia_h.py +2 -2
  17. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/conjugate_gradient/cg.py +16 -16
  18. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/__init__.py +1 -0
  19. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/newtonnewton.py +5 -5
  20. torchzero-0.3.15/torchzero/modules/experimental/spsa1.py +93 -0
  21. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/functional.py +7 -0
  22. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/grad_approximation/__init__.py +1 -1
  23. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/grad_approximation/forward_gradient.py +2 -5
  24. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/grad_approximation/rfdm.py +27 -110
  25. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/__init__.py +1 -1
  26. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/_polyinterp.py +3 -1
  27. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/adaptive.py +3 -3
  28. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/backtracking.py +1 -1
  29. torchzero-0.3.15/torchzero/modules/line_search/interpolation.py +160 -0
  30. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/line_search.py +11 -20
  31. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/scipy.py +15 -3
  32. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/line_search/strong_wolfe.py +3 -5
  33. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/misc.py +2 -2
  34. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/multistep.py +13 -13
  35. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/__init__.py +2 -0
  36. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  37. torchzero-0.3.15/torchzero/modules/quasi_newton/sg2.py +292 -0
  38. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/restarts/restars.py +5 -4
  39. torchzero-0.3.15/torchzero/modules/second_order/__init__.py +7 -0
  40. torchzero-0.3.15/torchzero/modules/second_order/ifn.py +89 -0
  41. torchzero-0.3.15/torchzero/modules/second_order/inm.py +105 -0
  42. torchzero-0.3.15/torchzero/modules/second_order/newton.py +293 -0
  43. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/second_order/newton_cg.py +86 -110
  44. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/second_order/nystrom.py +1 -1
  45. torchzero-0.3.15/torchzero/modules/second_order/rsn.py +227 -0
  46. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  47. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/trust_cg.py +6 -4
  48. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/wrappers/optim_wrapper.py +49 -42
  49. torchzero-0.3.15/torchzero/modules/zeroth_order/__init__.py +1 -0
  50. torchzero-0.3.15/torchzero/modules/zeroth_order/cd.py +122 -0
  51. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/derivatives.py +19 -19
  52. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/linear_operator.py +50 -2
  53. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/optimizer.py +2 -2
  54. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/python_tools.py +1 -0
  55. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero.egg-info/PKG-INFO +1 -1
  56. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero.egg-info/SOURCES.txt +11 -2
  57. torchzero-0.3.13/torchzero/core/__init__.py +0 -2
  58. torchzero-0.3.13/torchzero/core/module.py +0 -914
  59. torchzero-0.3.13/torchzero/modules/higher_order/__init__.py +0 -1
  60. torchzero-0.3.13/torchzero/modules/second_order/__init__.py +0 -4
  61. torchzero-0.3.13/torchzero/modules/second_order/newton.py +0 -383
  62. torchzero-0.3.13/torchzero/modules/zeroth_order/__init__.py +0 -1
  63. torchzero-0.3.13/torchzero/modules/zeroth_order/cd.py +0 -359
  64. {torchzero-0.3.13 → torchzero-0.3.15}/setup.cfg +0 -0
  65. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_identical.py +0 -0
  66. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_module.py +0 -0
  67. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_tensorlist.py +0 -0
  68. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_utils_optimizer.py +0 -0
  69. {torchzero-0.3.13 → torchzero-0.3.15}/tests/test_vars.py +0 -0
  70. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/__init__.py +0 -0
  71. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/__init__.py +0 -0
  72. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/adagrad.py +0 -0
  73. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/adam.py +0 -0
  74. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/adan.py +0 -0
  75. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/adaptive_heavyball.py +0 -0
  76. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/aegd.py +0 -0
  77. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/lion.py +0 -0
  78. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/lmadagrad.py +0 -0
  79. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/mars.py +0 -0
  80. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/msam.py +0 -0
  81. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/muon.py +0 -0
  82. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/natural_gradient.py +0 -0
  83. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/orthograd.py +0 -0
  84. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/rmsprop.py +0 -0
  85. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/rprop.py +0 -0
  86. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/sam.py +0 -0
  87. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/shampoo.py +0 -0
  88. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/adaptive/soap.py +0 -0
  89. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/clipping/__init__.py +0 -0
  90. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/clipping/clipping.py +0 -0
  91. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/clipping/ema_clipping.py +0 -0
  92. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/clipping/growth_clipping.py +0 -0
  93. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/conjugate_gradient/__init__.py +0 -0
  94. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/curveball.py +0 -0
  95. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/dct.py +0 -0
  96. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/fft.py +0 -0
  97. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/gradmin.py +0 -0
  98. {torchzero-0.3.13/torchzero/modules/higher_order → torchzero-0.3.15/torchzero/modules/experimental}/higher_order_newton.py +0 -0
  99. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/l_infinity.py +0 -0
  100. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/momentum.py +0 -0
  101. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/newton_solver.py +0 -0
  102. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  103. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/scipy_newton_cg.py +0 -0
  104. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/experimental/structural_projections.py +0 -0
  105. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/grad_approximation/fdm.py +0 -0
  106. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  107. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/least_squares/__init__.py +0 -0
  108. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/least_squares/gn.py +0 -0
  109. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/__init__.py +0 -0
  110. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/debug.py +0 -0
  111. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/escape.py +0 -0
  112. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/gradient_accumulation.py +0 -0
  113. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/homotopy.py +0 -0
  114. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/regularization.py +0 -0
  115. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/split.py +0 -0
  116. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/misc/switch.py +0 -0
  117. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/momentum/__init__.py +0 -0
  118. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/momentum/averaging.py +0 -0
  119. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/momentum/cautious.py +0 -0
  120. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/momentum/momentum.py +0 -0
  121. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/__init__.py +0 -0
  122. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/accumulate.py +0 -0
  123. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/binary.py +0 -0
  124. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/higher_level.py +0 -0
  125. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/multi.py +0 -0
  126. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/reduce.py +0 -0
  127. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/unary.py +0 -0
  128. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/ops/utility.py +0 -0
  129. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/projections/__init__.py +0 -0
  130. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/projections/cast.py +0 -0
  131. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/projections/galore.py +0 -0
  132. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/projections/projection.py +0 -0
  133. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/damping.py +0 -0
  134. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +0 -0
  135. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
  136. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/quasi_newton/lsr1.py +0 -0
  137. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/restarts/__init__.py +0 -0
  138. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/second_order/multipoint.py +0 -0
  139. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/smoothing/__init__.py +0 -0
  140. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/smoothing/laplacian.py +0 -0
  141. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/smoothing/sampling.py +0 -0
  142. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/step_size/__init__.py +0 -0
  143. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/step_size/adaptive.py +0 -0
  144. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/step_size/lr.py +0 -0
  145. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/termination/__init__.py +0 -0
  146. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/termination/termination.py +0 -0
  147. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/__init__.py +0 -0
  148. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/cubic_regularization.py +0 -0
  149. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/dogleg.py +0 -0
  150. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/trust_region/trust_region.py +0 -0
  151. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/variance_reduction/__init__.py +0 -0
  152. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/variance_reduction/svrg.py +0 -0
  153. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/weight_decay/__init__.py +0 -0
  154. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/weight_decay/weight_decay.py +0 -0
  155. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/modules/wrappers/__init__.py +0 -0
  156. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/__init__.py +0 -0
  157. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/root.py +0 -0
  158. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/utility/__init__.py +0 -0
  159. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/utility/split.py +0 -0
  160. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/__init__.py +0 -0
  161. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/directsearch.py +0 -0
  162. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/fcmaes.py +0 -0
  163. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/mads.py +0 -0
  164. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/nevergrad.py +0 -0
  165. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/nlopt.py +0 -0
  166. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/optuna.py +0 -0
  167. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/optim/wrappers/scipy.py +0 -0
  168. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/__init__.py +0 -0
  169. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/compile.py +0 -0
  170. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/__init__.py +0 -0
  171. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/benchmark.py +0 -0
  172. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  173. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/orthogonalize.py +0 -0
  174. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/qr.py +0 -0
  175. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/solve.py +0 -0
  176. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/linalg/svd.py +0 -0
  177. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/metrics.py +0 -0
  178. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/numberlist.py +0 -0
  179. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/ops.py +0 -0
  180. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/optuna_tools.py +0 -0
  181. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/params.py +0 -0
  182. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/tensorlist.py +0 -0
  183. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero/utils/torch_tools.py +0 -0
  184. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero.egg-info/dependency_links.txt +0 -0
  185. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero.egg-info/requires.txt +0 -0
  186. {torchzero-0.3.13 → torchzero-0.3.15}/torchzero.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.13
+ Version: 0.3.15
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  Project-URL: Homepage, https://github.com/inikishev/torchzero
pyproject.toml
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."
 
- version = "0.3.13"
+ version = "0.3.15"
  dependencies = [
  "torch",
  "numpy",
tests/test_opts.py
@@ -400,13 +400,6 @@ RandomizedFDM_4samples = Run(
      func='booth', steps=50, loss=1e-5, merge_invariant=True,
      sphere_steps=100, sphere_loss=400,
  )
- RandomizedFDM_4samples_lerp = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
-     needs_closure=True,
-     func='booth', steps=50, loss=1e-5, merge_invariant=True,
-     sphere_steps=100, sphere_loss=505,
- )
  RandomizedFDM_4samples_no_pre_generate = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
@@ -762,6 +755,7 @@ SR1 = Run(
      sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
      needs_closure=True,
      func='rosen', steps=50, loss=1e-12, merge_invariant=True,
+     # this reaches 1e-13 on github so don't change to 0
      sphere_steps=10, sphere_loss=0,
  )
  SSVM = Run(
@@ -813,7 +807,7 @@ NewtonCG = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
      needs_closure=True,
-     func='rosen', steps=20, loss=1e-7, merge_invariant=True,
+     func='rosen', steps=20, loss=1e-10, merge_invariant=True,
      sphere_steps=2, sphere_loss=3e-4,
  )
 
@@ -879,8 +873,8 @@ SophiaH = Run(
 
  # -------------------------- higher_order ------------------------- #
  HigherOrderNewton = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+     func_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
      needs_closure=True,
      func='rosen', steps=1, loss=2e-10, merge_invariant=True,
      sphere_steps=1, sphere_loss=1e-10,
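This last tests/test_opts.py hunk tracks the relocation of higher_order_newton.py into the experimental subpackage (file 98 in the listing above), so the class is now reached through tz.m.experimental rather than tz.m. A minimal before/after sketch, assuming the `import torchzero as tz` alias the tests use and a placeholder model:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(2, 1)  # placeholder model

    # 0.3.13: the class was exposed at the top level of tz.m
    # opt = tz.Modular(model.parameters(), tz.m.HigherOrderNewton(trust_method=None))

    # 0.3.15: it now lives under the experimental namespace
    opt = tz.Modular(model.parameters(), tz.m.experimental.HigherOrderNewton(trust_method=None))

As the needs_closure=True flag in the test indicates, the module is still driven with opt.step(closure).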
torchzero/core/__init__.py (new file in 0.3.15)
@@ -0,0 +1,5 @@
+ from .chain import Chain, maybe_chain
+ from .modular import Modular
+ from .module import Chainable, Module
+ from .transform import Target, TensorwiseTransform, Transform, apply_transform
+ from .var import Var
torchzero/core/chain.py (new file in 0.3.15)
@@ -0,0 +1,50 @@
+ from collections.abc import Iterable
+
+ from ..utils.python_tools import flatten
+ from .module import Module, Chainable
+
+
+ class Chain(Module):
+     """Chain of modules, mostly used internally"""
+     def __init__(self, *modules: Module | Iterable[Module]):
+         super().__init__()
+         flat_modules: list[Module] = flatten(modules)
+         for i, module in enumerate(flat_modules):
+             self.set_child(f'module_{i}', module)
+
+     def update(self, var):
+         # note here that `update` and `apply` shouldn't be used directly
+         # as it will update all modules, and then apply all modules
+         # it is used in specific cases like Chain as trust region hessian module
+         for i in range(len(self.children)):
+             self.children[f'module_{i}'].update(var)
+             if var.stop: break
+         return var
+
+     def apply(self, var):
+         for i in range(len(self.children)):
+             var = self.children[f'module_{i}'].apply(var)
+             if var.stop: break
+         return var
+
+     def step(self, var):
+         for i in range(len(self.children)):
+             var = self.children[f'module_{i}'].step(var)
+             if var.stop: break
+         return var
+
+     def __repr__(self):
+         s = self.__class__.__name__
+         if self.children:
+             if s == 'Chain': s = 'C' # to shorten it
+             s = f'{s}({", ".join(str(m) for m in self.children.values())})'
+         return s
+
+
+ def maybe_chain(*modules: Chainable) -> Module:
+     """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
+     flat_modules: list[Module] = flatten(modules)
+     if len(flat_modules) == 1:
+         return flat_modules[0]
+     return Chain(*flat_modules)
+
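A small sketch of the two branches of maybe_chain; the concrete modules (tz.m.RandomizedFDM, tz.m.LR) are borrowed from tests/test_opts.py above purely for illustration:

    import torchzero as tz
    from torchzero.core import Chain, maybe_chain

    # one module: returned as-is, no Chain wrapper
    single = maybe_chain(tz.m.LR(0.1))
    assert not isinstance(single, Chain)

    # several modules (nested iterables are flattened): wrapped in a Chain
    chained = maybe_chain(tz.m.RandomizedFDM(n_samples=4), tz.m.LR(0.1))
    assert isinstance(chained, Chain)
    print(chained)  # Chain.__repr__ abbreviates the class name to C(...)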
torchzero/core/functional.py (new file in 0.3.15)
@@ -0,0 +1,37 @@
+ from collections.abc import Sequence
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from .module import Module
+     from .var import Var
+
+
+ def step(var: "Var", modules: "Sequence[Module]",) -> "Var":
+     """steps with ``modules`` and returns modified ``var``, doesn't update parameters.
+
+     Args:
+         var (Var): Var object.
+         modules (Sequence[Module]): sequence of modules to step with.
+
+     Returns:
+         Var: modified Var
+     """
+     # n_modules = len(modules)
+     # if n_modules == 0: return var.clone(clone_update=False)
+     # last_module = modules[-1]
+     # last_lr = last_module.defaults.get('lr', None)
+
+     # step
+     for i, module in enumerate(modules):
+         if i!=0: var = var.clone(clone_update=False)
+
+         # last module, or next to last module before lr
+         # if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
+         #     if len(module.children) != 0 or is_nested: var.nested_is_last = True
+         #     else: var.is_last = True
+         #     if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
+
+         var = module.step(var)
+         if var.stop: break
+
+     return var
torchzero/core/modular.py (new file in 0.3.15)
@@ -0,0 +1,237 @@
+
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections import ChainMap, defaultdict
+ from collections.abc import Callable, Iterable, MutableMapping, Sequence
+ from operator import itemgetter
+ from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
+
+ import torch
+
+ from ..utils import (
+     Init,
+     ListLike,
+     Params,
+     _make_param_groups,
+     get_state_vals,
+     vec_to_tensors,
+ )
+ from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
+ from ..utils.linalg.linear_operator import LinearOperator
+ from ..utils.python_tools import flatten
+ from .module import Chainable, Module
+ from .var import Var
+ from .functional import step
+
+ class _EvalCounterClosure:
+     """keeps track of how many times closure has been evaluated, and sets closure return"""
+     __slots__ = ("modular", "closure")
+     def __init__(self, modular: "Modular", closure):
+         self.modular = modular
+         self.closure = closure
+
+     def __call__(self, *args, **kwargs):
+         if self.closure is None:
+             raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+         v = self.closure(*args, **kwargs)
+
+         # set closure return on 1st evaluation
+         if self.modular._closure_return is None:
+             self.modular._closure_return = v
+
+         self.modular.num_evaluations += 1
+         return v
+
+
+ def unroll_modules(*modules: Chainable) -> list[Module]:
+     unrolled = []
+
+     for m in modules:
+         if isinstance(m, Module):
+             unrolled.append(m)
+             unrolled.extend(unroll_modules(list(m.children.values())))
+         else:
+             unrolled.extend(unroll_modules(*m))
+
+     return unrolled
+
+
+ # Modular has to inherit from torch.optim.Optimizer to support lr schedulers,
+ # although Accelerate doesn't work due to converting param_groups to a dict
+ class Modular(torch.optim.Optimizer):
+     """Chains multiple modules into an optimizer.
+
+     Args:
+         params (Params | torch.nn.Module): An iterable of parameters to optimize
+             (typically `model.parameters()`), an iterable of parameter group dicts,
+             or a `torch.nn.Module` instance.
+         *modules (Module): A sequence of `Module` instances that define the
+             optimization algorithm steps.
+     """
+     # this is specifically for lr schedulers
+     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
+
+     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+         if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+         self.model: torch.nn.Module | None = None
+         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
+         if isinstance(params, torch.nn.Module):
+             self.model = params
+             params = params.parameters()
+
+         self.modules = modules
+         """Top-level modules provided during initialization."""
+
+         self.unrolled_modules = unroll_modules(self.modules)
+         """A flattened list of all modules including all children."""
+
+         param_groups = _make_param_groups(params, differentiable=False)
+         self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
+         """Maps each parameter tensor to a list of per-module global settings.
+         Each element in the list is the 2nd map of a module's ChainMap."""
+
+         # make sure there is no more than a single learning rate module
+         lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
+         if len(lr_modules) > 1:
+             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')
+
+         # iterate over all per-parameter settings overrides and check if they are applied at most once
+         for group in param_groups:
+             for k in group:
+                 if k in ('params', 'lr'): continue
+                 modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
+                 if len(modules_with_k) > 1:
+                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
+
+         # defaults for schedulers
+         defaults = {}
+         for m in self.unrolled_modules: defaults.update(m.defaults)
+         super().__init__(param_groups, defaults=defaults)
+
+         # note - this is what super().__init__(param_groups, defaults=defaults) does:
+
+         # self.defaults = defaults
+         # for param_group in param_groups:
+         #     self.add_param_group(param_group)
+
+         # add_param_group adds a ChainMap where defaults are lowest priority,
+         # and entries specified in param_groups or scheduler are higher priority.
+         # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
+         # in each module, settings passed to that module by calling set_param_groups are highest priority
+
+         self.current_step = 0
+         """global step counter for the optimizer."""
+
+         self.num_evaluations = 0
+         """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
+
+         # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+         # we want to return original loss so this attribute is used
+         self._closure_return = None
+         """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
+
+         self.attrs = {}
+         """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+         self.should_terminate = False
+         """is set to True by termination criteria modules."""
+
+     def add_param_group(self, param_group: dict[str, Any]):
+         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
+         self.param_groups.append(ChainMap(proc_param_group, self.defaults))
+         # setting param_group[key] = value sets it to first map (the `proc_param_group`).
+         # therefore lr schedulers override defaults, but not settings passed to individual modules
+         # by `set_param_groups`.
+
+         for p in proc_param_group['params']:
+             # updates global per-parameter setting overrides (medium priority)
+             self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
+
+     def state_dict(self):
+         all_params = [p for g in self.param_groups for p in g['params']]
+         id_to_idx = {id(p): i for i,p in enumerate(all_params)}
+
+         groups = []
+         for g in self.param_groups:
+             g = g.copy()
+             g['params'] = [id_to_idx[id(p)] for p in g['params']]
+             groups.append(g)
+
+         state_dict = {
+             "idx_to_id": {v:k for k,v in id_to_idx.items()},
+             "params": all_params,
+             "groups": groups,
+             "defaults": self.defaults,
+             "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
+         }
+         return state_dict
+
+     def load_state_dict(self, state_dict: dict):
+         self.defaults.clear()
+         self.defaults.update(state_dict['defaults'])
+
+         idx_to_param = dict(enumerate(state_dict['params']))
+         groups = []
+         for g in state_dict['groups']:
+             g = g.copy()
+             g['params'] = [idx_to_param[p] for p in g['params']]
+             groups.append(g)
+
+         self.param_groups.clear()
+         for group in groups:
+             self.add_param_group(group)
+
+         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
+         for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
+             m._load_state_dict(sd, id_to_tensor)
+
+
+     def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+         # clear closure return from previous step
+         self._closure_return = None
+
+         # propagate global per-parameter setting overrides
+         for g in self.param_groups:
+             settings = dict(g.maps[0]) # ignore defaults
+             params = settings.pop('params')
+             if not settings: continue
+
+             for p in params:
+                 if not p.requires_grad: continue
+                 for map in self._per_parameter_global_settings[p]: map.update(settings)
+
+         # create var
+         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
+         var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
+
+         # if closure is None, assume backward has been called and gather grads
+         if closure is None:
+             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+             self.num_evaluations += 1
+
+         if len(self.modules) == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
+
+         # step
+         var = step(var, self.modules)
+
+         # apply update
+         if not var.skip_update:
+             with torch.no_grad():
+                 torch._foreach_sub_(params, var.get_update())
+
+         # update attributes
+         self.attrs.update(var.attrs)
+         if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+         # hooks
+         for hook in var.post_step_hooks:
+             hook(self, var)
+
+         self.current_step += 1
+         #return var.loss if var.loss is not None else var.loss_approx
+         return self._closure_return
+
+     def __repr__(self):
+         return f'Modular({", ".join(str(m) for m in self.modules)})'
+
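A hedged end-to-end sketch of the new Modular on the closure-free path handled in step() above; the model, the batch, and the single tz.m.LR module are placeholders (the tests chain richer module stacks):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)                   # placeholder model
    xb, yb = torch.randn(8, 4), torch.randn(8, 1)   # placeholder batch

    # modules are chained left to right, as in tests/test_opts.py
    opt = tz.Modular(model.parameters(), tz.m.LR(0.1))

    # closure is None, so Modular gathers p.grad from a normal backward() call
    loss = torch.nn.functional.mse_loss(model(xb), yb)
    opt.zero_grad()
    loss.backward()
    opt.step()

    print(opt.current_step, opt.num_evaluations)    # both counters advance

Modules the tests mark needs_closure=True (for example tz.m.NewtonCG() with tz.m.StrongWolfe(fallback=True)) require opt.step(closure) instead; Modular wraps the closure in _EvalCounterClosure so num_evaluations counts every re-evaluation, and step() returns whatever the first closure call returned.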