torchzero-0.3.15-py3-none-any.whl → torchzero-0.4.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
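Read as a migration map, the renames above fall into three groups: torchzero.utils.linalg moves to a top-level torchzero.linalg package (entries 16-25, 179-182, 185), the monolithic torchzero/optim/wrappers/scipy.py is split into a torchzero.optim.wrappers.scipy package with one module per algorithm (entries 153-161, 178), and torchzero/core/var.py is superseded by torchzero/core/objective.py (entries 3, 13, 174, 175). A minimal import sketch of what this likely means for downstream code — the re-exported names are assumptions inferred from the file list alone, except Objective, which the new test below imports directly:

    # Hypothetical migration sketch inferred from the renames in this diff.
    # Which symbols each __init__.py actually re-exports is not verified here.

    # 0.3.15: from torchzero.utils.linalg import qr, solve, linear_operator
    from torchzero.linalg import qr, solve, linear_operator   # 0.4.1

    # 0.3.15: from torchzero.optim.wrappers import scipy
    from torchzero.optim.wrappers.scipy import minimize       # 0.4.1, per-algorithm modules

    # 0.3.15: per-step state lived in torchzero/core/var.py;
    # 0.4.1 replaces it with the Objective class exercised by the new tests:
    from torchzero.core import Objective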
tests/test_objective.py (new file)
@@ -0,0 +1,188 @@
+ import pytest
+ import torch
+ from torchzero.core import Objective
+ from torchzero.utils.tensorlist import TensorList
+
+ @torch.no_grad
+ def test_get_loss():
+
+     # ---------------------------- test that it works ---------------------------- #
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_1(backward=True):
+         assert not backward, 'backward = True'
+
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure_1, model=None, current_step=0)
+
+     assert obj.loss is None, obj.loss
+
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert evaluated, evaluated
+     assert loss is obj.loss
+     assert obj.loss == 4.0
+     assert obj.loss_approx == 4.0
+     assert obj.grads is None, obj.grads
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert obj.grads is None, obj.grads
+
+
+     # ----------------------- test that backward=True works ---------------------- #
+     params = [torch.tensor(3.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_2(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0] * 2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure_2, model=None, current_step=0)
+     assert obj.grads is None, obj.grads
+     assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+     assert obj.grads is not None
+     assert obj.grads[0] == 2.0, obj.grads
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+     assert obj.grads[0] == 2.0, obj.grads
+
+     # get grad, which should just return already evaluated grad
+     assert (grad := obj.get_grads())[0] == 2.0, grad
+     assert grad is obj.grads, grad
+
+     # get update, which should create and return cloned grad
+     assert obj.updates is None
+     assert (update := obj.get_updates())[0] == 2.0, update
+     assert update is obj.updates
+     assert update is not obj.grads
+     assert obj.grads is not None
+     assert update[0] == obj.grads[0]
+
+ @torch.no_grad
+ def test_get_grad():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure, model=None, current_step=0)
+     assert (grad := obj.get_grads())[0] == 4.0, grad
+     assert grad is obj.grads
+
+     assert obj.loss == 4.0
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+     assert obj.loss_approx == 4.0
+
+     assert obj.updates is None, obj.updates
+     assert (update := obj.get_updates())[0] == 4.0, update
+
+ @torch.no_grad
+ def test_get_update():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure, model=None, current_step=0)
+     assert obj.updates is None, obj.updates
+     assert (update := obj.get_updates())[0] == 4.0, update
+     assert update is obj.updates
+
+     assert (grad := obj.get_grads())[0] == 4.0, grad
+     assert grad is obj.grads
+     assert grad is not update
+
+     assert obj.loss == 4.0
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+     assert obj.loss_approx == 4.0
+
+     assert (update := obj.get_updates())[0] == 4.0, update
+
+
+ def _assert_objectives_are_same_(o1: Objective, o2: Objective, clone_update: bool):
+     for k,v in o1.__dict__.items():
+         if not k.startswith('__'):
+             # if k == 'post_step_hooks': continue
+             if k == 'storage': continue
+             elif k == 'updates' and clone_update:
+                 if o1.updates is None or o2.updates is None:
+                     assert o1.updates is None and o2.updates is None, f'`{k}` attribute is not the same, {o1.updates = }, {o2.updates = }'
+                 else:
+                     assert (TensorList(o1.updates) == TensorList(o2.updates)).global_all()
+                     assert o1.updates is not o2.updates
+             elif k == 'params':
+                 for p1, p2 in zip(o1.params, o2.params):
+                     assert p1.untyped_storage() == p2.untyped_storage()
+             else:
+                 assert getattr(o2, k) is v, f'`{k}` attribute is not the same, {getattr(o1, k) = }, {getattr(o2, k) = }'
+
+ def test_var_clone():
+     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
+     def closure(backward): return 1
+     obj = Objective(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+     obj.grads = TensorList(torch.randn(5))
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+     obj.updates = TensorList(torch.randn(5) * 2)
+     obj.loss = torch.randn(1)
+     obj.loss_approx = obj.loss
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
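
Taken together, the new tests pin down the closure contract behind Objective: the closure receives a backward flag, is run under torch.enable_grad() when gradients are requested, clears and repopulates .grad itself, and is evaluated at most once per step; get_loss and get_grads reuse that cached evaluation, while get_updates returns a clone of the gradients that can be modified freely. A minimal usage sketch distilled from the assertions above (the surrounding setup is illustrative, not taken from the package):

    import torch
    from torchzero.core import Objective

    params = [torch.tensor(2.0, requires_grad=True)]

    def closure(backward=True):
        # Objective calls this under torch.enable_grad() when backward=True.
        loss = params[0] ** 2
        if backward:
            params[0].grad = None   # clear stale gradients before backward
            loss.backward()
        return loss

    obj = Objective(params=params, closure=closure, model=None, current_step=0)
    loss = obj.get_loss(backward=True)   # evaluates the closure exactly once
    grads = obj.get_grads()              # reuses gradients from that evaluation
    updates = obj.get_updates()          # clone of grads, safe to modify in place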