torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_identical.py CHANGED
@@ -97,30 +97,30 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
 @pytest.mark.parametrize('amsgrad', [True, False])
 def test_adam(amsgrad):
     torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
-    tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
-    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
-    tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
-    tz_fn4 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
-    tz_fn5 = lambda p: tz.Modular(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
-    tz_fn_ops = lambda p: tz.Modular(
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
+    tz_fn3 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
+    tz_fn4 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
+    tz_fn5 = lambda p: tz.Optimizer(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
+    tz_fn_ops = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             tz.m.EMA(0.9, debiased=True),
             [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
         ))
-    tz_fn_ops2 = lambda p: tz.Modular(
+    tz_fn_ops2 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
         ))
-    tz_fn_ops3 = lambda p: tz.Modular(
+    tz_fn_ops3 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
         ))
-    tz_fn_ops4 = lambda p: tz.Modular(
+    tz_fn_ops4 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
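Every change in this hunk is the entry-point rename: `tz.Modular(...)` becomes `tz.Optimizer(...)` with arguments unchanged. A minimal sketch of the equivalence the test asserts, assuming `tz.Optimizer` keeps the old `tz.Modular` constructor signature and the usual `torch.optim`-style `step()`/`zero_grad()` (everything outside the diff is illustrative):

    import torch
    import torchzero as tz

    # two identically-initialized parameters, one per optimizer
    p1 = torch.nn.Parameter(torch.ones(3))
    p2 = torch.nn.Parameter(torch.ones(3))

    torch_opt = torch.optim.Adam([p1], lr=1)
    tz_opt = tz.Optimizer([p2], tz.m.Adam(), tz.m.LR(1))  # trailing LR(1) exercises LR fusing

    for _ in range(10):
        torch_opt.zero_grad(); (p1 ** 2).sum().backward(); torch_opt.step()
        tz_opt.zero_grad(); (p2 ** 2).sum().backward(); tz_opt.step()

    assert torch.allclose(p1, p2, atol=1e-6)  # trajectories should match to tolerance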
@@ -145,19 +145,19 @@ def test_adam(amsgrad):
 @pytest.mark.parametrize('amsgrad', [True, False])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
-    tz_fn = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
-    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
     _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
 
 @pytest.mark.parametrize('centered', [True, False])
 def test_rmsprop(centered):
     torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
-    tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(centered=centered, init='zeros'))
-    tz_fn2 = lambda p: tz.Modular(
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(centered=centered, init='zeros'))
+    tz_fn2 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
     )
-    tz_fn3 = lambda p: tz.Modular(
+    tz_fn3 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
     )
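`tz_fn2` and `tz_fn3` rebuild RMSprop from primitives, differing only in whether the square root is fused into the EMA module. The identity both rely on, as a hedged sketch in plain tensor arithmetic (non-centered branch, zero-initialized accumulator matching `init='zeros'`):

    import torch

    def rmsprop_update(g, sq_ema, beta=0.99, eps=1e-8):
        # EMASquared: exponential moving average of the squared gradient
        sq_ema.mul_(beta).addcmul_(g, g, value=1 - beta)
        # Sqrt() then Add(eps), then the outer Div by this denominator
        return g / (sq_ema.sqrt() + eps)

The centered branch additionally subtracts the square of an EMA of the raw gradient from `sq_ema` before the square root.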
@@ -173,7 +173,7 @@ def test_rmsprop(centered):
 @pytest.mark.parametrize('centered', [True, False])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_rmsprop_hyperparams(beta, eps, centered, lr):
-    tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
     torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)
 
@@ -185,7 +185,7 @@ def test_rmsprop_hyperparams(beta, eps, centered, lr):
 @pytest.mark.parametrize('ub', [50, 1.5])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_rprop(nplus, nminus, lb, ub, lr):
-    tz_fn = lambda p: tz.Modular(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
     torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
     _assert_identical_merge_closure(tz_fn, 'cpu', 30)
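Note the argument-order difference the test bridges: `torch.optim.Rprop` takes `etas=(nminus, nplus)` and `step_sizes=(lb, ub)`, while the tz module takes `nplus` first. A hedged sketch of the sign-based rule both implement (with `backtrack=False`, i.e. no step reversal on a sign flip):

    import torch

    def rprop_step(p, g, prev_g, step, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
        s = g * prev_g
        step = torch.where(s > 0, (step * nplus).clamp(max=ub), step)   # signs agree: grow
        step = torch.where(s < 0, (step * nminus).clamp(min=lb), step)  # sign flip: shrink
        g = torch.where(s < 0, torch.zeros_like(g), g)  # drop the gradient on a flip
        p = p - step * g.sign()                         # move by step size, sign only
        return p, g, step                               # g becomes prev_g next iteration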
@@ -193,8 +193,8 @@ def test_rprop(nplus, nminus, lb, ub, lr):
 
 def test_adagrad():
     torch_fn = lambda p: torch.optim.Adagrad(p, 1)
-    tz_fn = lambda p: tz.Modular(p, tz.m.Adagrad(), tz.m.LR(1))
-    tz_fn2 = lambda p: tz.Modular(
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adagrad(), tz.m.LR(1))
+    tz_fn2 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
     )
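The `tz_fn2` chain spells Adagrad out operator by operator: square, accumulate, square-root, add eps, divide. The same identity as a plain-tensor sketch:

    import torch

    def adagrad_update(g, accum, eps=1e-10):
        accum += g ** 2                   # Pow(2) then AccumulateSum()
        return g / (accum.sqrt() + eps)   # Sqrt(), Add(1e-10), outer Div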
@@ -212,15 +212,15 @@ def test_adagrad():
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
     torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
-    tz_fn1 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
-    tz_fn2 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
+    tz_fn1 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
     _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
 
 
 @pytest.mark.parametrize('tensorwise', [True, False])
 def test_graft(tensorwise):
-    graft1 = lambda p: tz.Modular(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
-    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft1 = lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft2 = lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
     _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
     for fn in [graft1, graft2]:
         if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
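Grafting combines the norm of one update with the direction of another (here LBFGS and RMSprop), either over all parameters at once or per tensor (`tensorwise`). A hedged sketch of the per-tensor rule; which of `tz.m.Graft`'s two arguments supplies magnitude versus direction is not pinned down by this diff:

    import torch

    def graft(magnitude_update, direction_update, eps=1e-8):
        # rescale the direction to carry the magnitude source's norm
        scale = magnitude_update.norm() / (direction_update.norm() + eps)
        return direction_update * scale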