torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,299 @@
+ """Subspace optimizers to be used in a low rank eigenbasis.
+
+ Three optimizers support this - GGT and the experimental AdaNystrom and Eigengrad.
+
+ I could define reproject on a module, but because most opts use per-parameter state, that is complicated."""
+
+ import math
+ from abc import ABC, abstractmethod
+ from typing import Any, cast
+
+ import torch
+
+ from ...linalg import matrix_power_eigh, torch_linalg
+ from .lion import lion_
+
+ class LREOptimizerBase(ABC):
+     """Optimizer to run in a low rank eigenbasis.
+
+     Notes:
+
+     1. It shouldn't store any state in self, everything should be in the state dict.
+        This is because this may be called on multiple parameters in a sequence.
+
+     2. step is always called first, then reproject whenever the eigenbasis gets updated.
+
+     3. L is the variance in the eigenbasis.
+     """
+     @abstractmethod
+     def step(self, g: torch.Tensor, L: torch.Tensor, Q: torch.Tensor, state: dict) -> torch.Tensor:
+         ...
+
+     @abstractmethod
+     def reproject(self, L_old: torch.Tensor, Q_old: torch.Tensor,
+                   L_new: torch.Tensor, Q_new: torch.Tensor, state: dict) -> None:
+         ...
+
+ class Whiten(LREOptimizerBase):
+     """This simply applies whitening and is equivalent to not running an optimizer in the eigenbasis."""
+     def step(self, g, L, Q, state): return (Q * L.rsqrt()) @ (Q.T @ g)
+     def reproject(self, L_old, Q_old, L_new, Q_new, state): pass
+
+ class EMA(LREOptimizerBase):
+     """Maintains exponential moving average of gradients in the low rank eigenbasis. Nesterov setting is experimental"""
+     def __init__(self, beta=0.9, nesterov:bool=False, cautious:bool=False, whiten:bool=True):
+         self.beta = beta
+         self.nesterov = nesterov
+         self.whiten = whiten
+         self.cautious = cautious
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         if "exp_avg" not in state:
+             state["exp_avg"] = torch.zeros_like(g)
+
+         exp_avg = state["exp_avg"]
+         exp_avg.lerp_(g, 1-self.beta)
+
+         if self.nesterov:
+             dir = (g + exp_avg * self.beta) / (1 + self.beta)
+         else:
+             dir = exp_avg
+
+         if self.cautious:
+             mask = (g * dir) > 0
+             dir *= mask
+
+         if self.whiten: return (Q * L.rsqrt()) @ dir
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+         C = Q_new.T @ Q_old
+         state["exp_avg"] = C @ state["exp_avg"]
+
+
+ def adam(g:torch.Tensor, state:dict, beta1, beta2, eps):
+
+     if "exp_avg" not in state:
+         state["exp_avg"] = torch.zeros_like(g)
+         state["exp_avg_sq"] = torch.zeros_like(g)
+         state["current_step"] = 1
+
+     exp_avg = state["exp_avg"]
+     exp_avg_sq = state["exp_avg_sq"]
+     current_step = state["current_step"]
+
+     exp_avg.lerp_(g, 1-beta1)
+     exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1-beta2)
+     denom = exp_avg_sq.sqrt().add_(eps)
+
+     bias_correction1 = 1.0 - (beta1 ** current_step)
+     bias_correction2 = 1.0 - (beta2 ** current_step)
+     alpha = math.sqrt(bias_correction2) / bias_correction1
+     state["current_step"] = current_step + 1
+
+     return (exp_avg * alpha) / denom
+
+ def _squared_reproject(C: torch.Tensor, sq: torch.Tensor, exact: bool):
+     if exact:
+         return (C @ sq.diag_embed() @ C.T).diagonal()
+
+     return C.square() @ sq
+
+ class Adam(LREOptimizerBase):
+     """Runs Adam in low rank eigenbasis."""
+     def __init__(self, beta1=0.9, beta2=0.95, cautious:bool=False, eps=1e-8, exact_reproject:bool=True):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.eps = eps
+         self.cautious = cautious
+         self.exact_reproject = exact_reproject
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         dir = adam(g, state, self.beta1, self.beta2, self.eps)
+
+         if self.cautious:
+             mask = (g * dir) > 0
+             dir *= mask
+
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+         C = Q_new.T @ Q_old
+
+         state["exp_avg"] = C @ state["exp_avg"]
+         state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+ class FullMatrixAdam(LREOptimizerBase):
+     """Runs full-matrix Adam in low rank eigenbasis.
+     The preconditioner is updated whenever the basis is updated."""
+     def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, matrix_power=-1/2, abs=True, cautious:bool=False):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.eps = eps
+         self.matrix_power = matrix_power
+         self.abs = abs
+         self.cautious = cautious
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         # initialize
+         if "exp_avg" not in state:
+             state["exp_avg"] = torch.zeros_like(g)
+             state["covariance"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+             state["preconditioner"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+             state["reprojected"] = True
+             state["current_step"] = 1
+
+         exp_avg = state["exp_avg"]
+         covariance = state["covariance"]
+         current_step = state["current_step"]
+
+         # update buffers
+         exp_avg.lerp_(g, 1-self.beta1)
+         covariance.lerp_(g.outer(g), weight=1-self.beta2)
+
+         # correct bias
+         bias_correction1 = 1.0 - (self.beta1 ** current_step)
+         exp_avg = exp_avg / bias_correction1
+
+         # after reprojecting update the preconditioner
+         if state["reprojected"]:
+             state["reprojected"] = False
+
+             bias_correction2 = 1.0 - (self.beta2 ** current_step)
+             covariance = covariance / bias_correction2
+
+             reg = torch.eye(covariance.size(0), device=covariance.device, dtype=covariance.dtype).mul_(self.eps)
+             covariance = covariance + reg
+
+             # compute matrix power
+             try:
+                 state["preconditioner"] = matrix_power_eigh(covariance, self.matrix_power, abs=self.abs)
+
+             except torch.linalg.LinAlgError:
+
+                 # fallback to diagonal
+                 state["preconditioner"] = covariance.diagonal().rsqrt().diag_embed()
+
+         # compute the update
+         state["current_step"] = current_step + 1
+         preconditioner = state["preconditioner"]
+         dir = preconditioner @ exp_avg
+
+         if self.cautious:
+             mask = (g * dir) > 0
+             dir *= mask
+
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+
+         state["reprojected"] = True
+
+         C = Q_new.T @ Q_old
+         state["exp_avg"] = C @ state["exp_avg"]
+         state["covariance"] = C @ state["covariance"] @ C.T
+
+ class Lion(LREOptimizerBase):
+     """Runs Lion in the low rank eigenbasis."""
+     def __init__(self, beta1=0.9, beta2=0.99, cautious:bool=False):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.cautious = cautious
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         if "exp_avg" not in state:
+             state["exp_avg"] = torch.zeros_like(g)
+
+         dir = cast(torch.Tensor, lion_(g, state["exp_avg"], beta1=self.beta1, beta2=self.beta2))
+
+         if self.cautious:
+             mask = (g * dir) > 0
+             dir *= mask
+
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+         C = Q_new.T @ Q_old
+         state["exp_avg"] = C @ state["exp_avg"]
+
+
+ class Grams(LREOptimizerBase):
+     """Runs Grams in low rank eigenbasis."""
+     def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, exact_reproject=True):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.eps = eps
+         self.exact_reproject = exact_reproject
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+         dir = adam(g, state, self.beta1, self.beta2, self.eps)
+         return Q @ dir.copysign(g)
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+         C = Q_new.T @ Q_old
+
+         state["exp_avg"] = C @ state["exp_avg"]
+         state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+ class LaProp(LREOptimizerBase):
+     """Runs LaProp in low rank eigenbasis."""
+     def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, cautious:bool=False, exact_reproject=True):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.eps = eps
+         self.cautious = cautious
+         self.exact_reproject = exact_reproject
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         if "exp_avg" not in state:
+             state["exp_avg"] = torch.zeros_like(g)
+             state["exp_avg_sq"] = torch.zeros_like(g)
+             state["current_step"] = 1
+
+         exp_avg = state["exp_avg"]
+         exp_avg_sq = state["exp_avg_sq"]
+         current_step = state["current_step"]
+
+         # update second moments
+         exp_avg_sq.mul_(self.beta2).addcmul_(g, g, value=1-self.beta2)
+         bias_correction2 = 1.0 - (self.beta2 ** current_step)
+
+         # divide by bias corrected second moments
+         dir = g / (exp_avg_sq / bias_correction2).sqrt().add_(self.eps)
+
+         # update first moments and bias correct
+         exp_avg.lerp_(dir, 1-self.beta1)
+         bias_correction1 = 1.0 - (self.beta1 ** current_step)
+         dir = exp_avg / bias_correction1
+
+         if self.cautious:
+             mask = (g * dir) > 0
+             dir *= mask
+
+         state["current_step"] = current_step + 1
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+         C = Q_new.T @ Q_old
+
+         state["exp_avg"] = C @ state["exp_avg"]
+         state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
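The step/reproject protocol above is meant to be driven by whichever module owns the eigenbasis (per the module docstring: GGT and the experimental AdaNystrom and Eigengrad). A minimal sketch of that call order on a toy quadratic, assuming Adam is importable from torchzero.modules.adaptive.lre_optimizers in an installed 0.4.1; the periodic eigendecomposition of a gradient-covariance estimate below is a hypothetical stand-in for the owning module's own factorization, not code from this diff:

```python
import torch
from torchzero.modules.adaptive.lre_optimizers import Adam

torch.manual_seed(0)
n, rank, lr = 20, 5, 0.1
A = torch.randn(n, n)
H = A @ A.T / n + torch.eye(n)          # SPD Hessian of the toy objective 0.5 * x^T H x
x = torch.randn(n)

opt, state, history = Adam(beta1=0.9, beta2=0.95), {}, []
L = Q = None
for t in range(200):
    g = H @ x                           # gradient of the toy objective
    history.append(g)
    if t % 10 == 0:                     # basis refresh: carry optimizer state into the new basis
        G = torch.stack(history[-50:])
        cov = G.T @ G / G.shape[0]      # crude gradient-covariance estimate (stand-in)
        vals, vecs = torch.linalg.eigh(cov)
        L_new, Q_new = vals[-rank:].clamp_min(1e-12), vecs[:, -rank:]
        if Q is not None:
            opt.reproject(L, Q, L_new, Q_new, state)
        L, Q = L_new, Q_new
    x -= lr * opt.step(g, L, Q, state)  # step() projects g into the basis and maps the update back
print(float(0.5 * x @ H @ x))           # objective decreases along the captured subspace
```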
@@ -1,6 +1,6 @@
  import torch

- from ...core import Transform
+ from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


@@ -20,7 +20,7 @@ def mars_correction_(

      return c

- class MARSCorrection(Transform):
+ class MARSCorrection(TensorTransform):
      """MARS variance reduction correction.

      Place any other momentum-based optimizer after this,
@@ -35,7 +35,7 @@ class MARSCorrection(Transform):

      Mars-AdamW
      ```python
-     optimizer = tz.Modular(
+     optimizer = tz.Optimizer(
          model.parameters(),
          tz.m.MARSCorrection(beta=0.95),
          tz.m.Adam(beta1=0.95, beta2=0.99),
@@ -46,7 +46,7 @@ class MARSCorrection(Transform):

      Mars-Lion
      ```python
-     optimizer = tz.Modular(
+     optimizer = tz.Optimizer(
          model.parameters(),
          tz.m.MARSCorrection(beta=0.9),
          tz.m.Lion(beta1=0.9),
@@ -61,11 +61,11 @@ class MARSCorrection(Transform):
          scaling: float = 0.025,
          max_norm: float | None = 1,
      ):
-         defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
-         super().__init__(defaults, uses_grad=False)
+         defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+         super().__init__(defaults)

      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
          beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
          max_norm = settings[0]['max_norm']
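The tz.Modular to tz.Optimizer edits in the docstring above reflect the renamed top-level entry point in 0.4.x. A minimal usage sketch, assuming tz.Optimizer keeps the drop-in torch.optim-style interface that tz.Modular had and that tz.m.LR (listed in the file table as torchzero/modules/step_size/lr.py) supplies the step size; neither assumption is confirmed by this hunk:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
optimizer = tz.Optimizer(          # 0.3.x spelled this tz.Modular
    model.parameters(),
    tz.m.MARSCorrection(beta=0.95),
    tz.m.Adam(beta1=0.95, beta2=0.99),
    tz.m.LR(1e-2),                 # assumed step-size module
)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```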
@@ -1,14 +1,13 @@
  from typing import Literal
- from collections.abc import Callable
+
  import torch

- from ...core import Module, apply_transform, Chainable
- from ...utils import NumberList, TensorList, as_tensorlist
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
- from ..functional import initial_step_size
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
+ from ..opt_utils import initial_step_size


- class MatrixMomentum(Module):
+ class MatrixMomentum(Transform):
      """Second order momentum method.

      Matrix momentum is useful for convex objectives; also, for some reason, it has really good generalization on elastic net logistic regression.
@@ -23,17 +22,17 @@ class MatrixMomentum(Module):
      Args:
          mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
          hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+             Determines how hessian-vector products are computed.
+
+             - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+             - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+             - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+             - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+             Defaults to ``"autograd"``.
+         h (float, optional):
+             The step size for finite difference if ``hvp_method`` is
+             ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
          hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.

      Reference:
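For reference, the two finite-difference modes described in the new docstring correspond to the standard gradient-difference approximations of a Hessian-vector product. A self-contained sketch of those formulas, independent of the module's actual implementation:

```python
import torch

def f(x):                      # any twice-differentiable scalar objective
    return (x ** 4).sum() + x[0] * x[1]

def grad(x):
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(f(x), x)[0]

x, v, h = torch.randn(5), torch.randn(5), 1e-3
v = v / v.norm()

hv_forward = (grad(x + h * v) - grad(x)) / h              # "fd_forward": one extra gradient per HVP
hv_central = (grad(x + h * v) - grad(x - h * v)) / (2 * h) # "fd_central": two extra gradients per HVP
hv_exact = torch.autograd.functional.hvp(f, x, v)[1]       # autograd reference
print((hv_forward - hv_exact).norm(), (hv_central - hv_exact).norm())
```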
@@ -44,51 +43,45 @@ class MatrixMomentum(Module):
          self,
          lr:float,
          mu=0.1,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         hvp_method: HVPMethod = "autograd",
          h: float = 1e-3,
          adaptive:bool = False,
          adapt_freq: int | None = None,
-         hvp_tfm: Chainable | None = None,
+
+         inner: Chainable | None = None,
      ):
          defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
-         super().__init__(defaults)
-
-         if hvp_tfm is not None:
-             self.set_child('hvp_tfm', hvp_tfm)
+         super().__init__(defaults, inner=inner)

      def reset_for_online(self):
          super().reset_for_online()
          self.clear_state_keys('p_prev')

      @torch.no_grad
-     def update(self, var):
-         assert var.closure is not None
-         p = TensorList(var.params)
-         p_prev = self.get_state(p, 'p_prev', init=var.params)
+     def update_states(self, objective, states, settings):
+         step = self.increment_counter("step", 0)
+         p = TensorList(objective.params)
+         p_prev = unpack_states(states, p, 'p_prev', init=p)

-         hvp_method = self.defaults['hvp_method']
-         h = self.defaults['h']
-         step = self.global_state.get("step", 0)
-         self.global_state["step"] = step + 1
+         fs = settings[0]
+         hvp_method = fs['hvp_method']
+         h = fs['h']

          if step > 0:
              s = p - p_prev

-             Hs, _ = var.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_graph=False)
+             Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
              Hs = [t.detach() for t in Hs]

-             if 'hvp_tfm' in self.children:
-                 Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
-
              self.store(p, ("Hs", "s"), (Hs, s))

              # -------------------------------- adaptive mu ------------------------------- #
-             if self.defaults["adaptive"]:
-                 g = TensorList(var.get_grad())
+             if fs["adaptive"]:
+                 g = TensorList(objective.get_grads())

-                 if self.defaults["adapt_freq"] is None:
+                 if fs["adapt_freq"] is None:
                      # ---------------------------- deterministic case ---------------------------- #
-                     g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
+                     g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
                      y = g - g_prev
                      g_prev.copy_(g)
                      denom = y.global_vector_norm()
@@ -101,14 +94,14 @@ class MatrixMomentum(Module):

                      # we start on the 1st step, and want to adapt when we start, so use (step - 1)
                      if (step - 1) % adapt_freq == 0:
-                         assert var.closure is not None
-                         params = TensorList(var.params)
+                         assert objective.closure is not None
+                         params = TensorList(objective.params)
                          p_cur = params.clone()

                          # move to previous params and evaluate p_prev with current mini-batch
-                         params.copy_(self.get_state(var.params, 'p_prev'))
+                         params.copy_(unpack_states(states, p, 'p_prev'))
                          with torch.enable_grad():
-                             var.closure()
+                             objective.closure()
                          g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                          y = g - g_prev

@@ -119,12 +112,12 @@ class MatrixMomentum(Module):
                          denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                          self.global_state["mu_mul"] = s.global_vector_norm() / denom

-         torch._foreach_copy_(p_prev, var.params)
+         torch._foreach_copy_(p_prev, objective.params)

      @torch.no_grad
-     def apply(self, var):
-         update = TensorList(var.get_update())
-         lr,mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)
+     def apply_states(self, objective, states, settings):
+         update = TensorList(objective.get_updates())
+         lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)

          if "mu_mul" in self.global_state:
              mu = mu * self.global_state["mu_mul"]
@@ -133,14 +126,17 @@ class MatrixMomentum(Module):
          # p_prev is not available so make a small step
          step = self.global_state["step"]
          if step == 1:
-             if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
+             if self.defaults["adaptive"]:
+                 # initialize
+                 unpack_states(states, objective.params, "g_prev", init=objective.get_grads())
+
              update.mul_(lr) # separate so that initial_step_size can clip correctly
              update.mul_(initial_step_size(update, 1e-7))
-             return var
+             return objective

          # -------------------------- matrix momentum update -------------------------- #
-         s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)
+         s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)

          update.mul_(lr).sub_(s).add_(Hs*mu)
-         var.update = update
-         return var
+         objective.updates = update
+         return objective
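Assuming the usual convention that the returned update is subtracted from the parameters (not shown in this hunk), update.mul_(lr).sub_(s).add_(Hs*mu) realizes the matrix momentum recurrence p_{k+1} = p_k - lr*g_k + (I - mu*H)(p_k - p_{k-1}). A self-contained sketch of that recurrence on a toy quadratic, with an exact Hessian standing in for the module's hessian_vector_product call:

```python
import torch

torch.manual_seed(0)
n, lr, mu = 10, 0.05, 0.1
A = torch.randn(n, n)
H = A @ A.T / n + torch.eye(n)      # SPD Hessian of 0.5 * x^T H x, minimizer at x = 0
x = torch.randn(n)
x_prev = x.clone()

for _ in range(500):
    g = H @ x                       # gradient
    s = x - x_prev                  # previous step, p - p_prev in the module
    Hs = H @ s                      # exact HVP; the module approximates this via autograd or finite differences
    update = lr * g - s + mu * Hs   # matches update.mul_(lr).sub_(s).add_(Hs*mu)
    x_prev = x.clone()
    x = x - update                  # assumes the chain subtracts the returned update
print(float(x.norm()))              # -> close to 0
```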