torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload
+from typing import Any, final, overload, Literal, cast

 import torch

@@ -14,7 +14,9 @@ from ..utils import (
     _make_param_groups,
     get_state_vals,
 )
+from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten
+from ..utils.linalg.linear_operator import LinearOperator


 def _closure_backward(closure, params, retain_graph, create_graph):
@@ -32,11 +34,9 @@ def _closure_backward(closure, params, retain_graph, create_graph):
 # ----------------------------------- var ----------------------------------- #
 class Var:
     """
-    Holds the state and context passed between optimizer modules during a step.
+    Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
+    Modules take in a ``Var`` object, modify and it is passed to the next module.

-    This class acts as a mutable container for information relevant to the current
-    optimization step, such as parameters, gradients, loss, and the computed update.
-    Modules read from and write to this object to coordinate their actions.
     """
     def __init__(
         self,
@@ -44,6 +44,10 @@ class Var:
         closure: Callable | None,
         model: torch.nn.Module | None,
         current_step: int,
+        parent: "Var | None" = None,
+        modular: "Modular | None" = None,
+        loss: torch.Tensor | None = None,
+        storage: dict | None = None,
     ):
         self.params: list[torch.Tensor] = params
         """List of all parameters with requires_grad = True."""
@@ -55,19 +59,31 @@ class Var:
         """torch.nn.Module object of the model, None if it wasn't specified."""

         self.current_step: int = current_step
-        """global current step, starts at 0"""
+        """global current step, starts at 0. This may not correspond to module current step,
+        for example a module may step every 10 global steps."""
+
+        self.parent: "Var | None" = parent
+        """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
+        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
+        e.g. when projecting."""
+
+        self.modular: "Modular" = cast(Modular, modular)
+        """Modular optimizer object that created this ``Var``."""

         self.update: list[torch.Tensor] | None = None
         """
-        current update, at the end this is subtracted from model parameters unless it is None.
+        current update. Update is assumed to be a transformed gradient, therefore it is subtracted.

         If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
+
+        At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
+        gradient will be used and calculated if needed.
         """

         self.grad: list[torch.Tensor] | None = None
-        """gradient with current parameters. If closure is not None, this is set to None and can be calculated if needed."""
+        """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""

-        self.loss: torch.Tensor | Any | None = None
+        self.loss: torch.Tensor | Any | None = loss
         """loss with current parameters."""

         self.loss_approx: torch.Tensor | Any | None = None
@@ -76,24 +92,28 @@ class Var:

         self.post_step_hooks: list[Callable[[Modular, Var]]] = []
         """list of functions to be called after optimizer step.
-        The signature is:

-        .. code:: py
+        This attribute should always be modified in-place (using ``append`` or ``extend``).

-            def hook(optimizer: Modular, var: Vars): ...
+        The signature is:

+        ```python
+        def hook(optimizer: Modular, var: Vars): ...
+        ```
         """

         self.is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module.
         This is always False if current module has children or is a child.
+        This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
         """

         self.nested_is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module, for modules
-        that have children.
+        that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
+        a child of a last module may also receive ``var.nested_is_last=True``.
         """

         self.last_module_lrs: list[float] | None = None
@@ -104,16 +124,30 @@ class Var:
         """

         self.stop: bool = False
-        """if True, all following modules will be skipped."""
+        """if True, all following modules will be skipped.
+        If this module is a child, it only affects modules at the same level (in the same Chain)."""

         self.skip_update: bool = False
-        """if True, the parameters will not be updated"""
+        """if True, the parameters will not be updated."""

-    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
-        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
-        Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
+        # self.storage: dict = {}
+        # """Storage for any other data, such as hessian estimates, etc."""

+        self.attrs: dict = {}
+        """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
+
+        if storage is None: storage = {}
+        self.storage: dict = storage
+        """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
+
+        self.should_terminate: bool | None = None
+        """termination criteria, Modular.should_terminate is set to this after each step if not None"""
+
+    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
+        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
+        Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
         if self.loss is None:
+
             if self.closure is None: raise RuntimeError("closure is None")
             if backward:
                 with torch.enable_grad():
@@ -124,7 +158,10 @@ class Var:
                 # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
                 # it is technically a more correct approach for when some parameters conditionally receive gradients
                 # and in this case it shouldn't be slower.
-                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+                # next time closure() is called, it will set grad to None.
+                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
+                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
             else:
                 self.loss = self.loss_approx = self.closure(False)

@@ -139,11 +176,24 @@ class Var:
                     closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                 )
                 self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+        # set parent grad
+        if self.parent is not None:
+            # the way projections/split work, they make a new closure which evaluates original
+            # closure and projects the gradient, and set it as their var.closure.
+            # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
+            # and we set it to parent var here.
+            if self.parent.loss is None: self.parent.loss = self.loss
+            if self.parent.grad is None and backward:
+                if all(p.grad is None for p in self.parent.params):
+                    warnings.warn("Parent grad is None after backward.")
+                self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
+
         return self.loss # type:ignore

     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-        :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
+        ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
         if self.grad is None:
             if self.closure is None: raise RuntimeError("closure is None")
             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +202,21 @@ class Var:
         return self.grad

     def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
-        Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
+        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
+        Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
         Do not call this at perturbed parameters."""
         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
         return self.update

-    def clone(self, clone_update: bool):
-        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
-        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
+    def clone(self, clone_update: bool, parent: "Var | None" = None):
+        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
+
+        Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
+
+        Setting ``parent`` is only if clone's parameters are something different,
+        while clone's closure referes to the same objective but with a "view" on parameters.
+        """
+        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)

         if clone_update and self.update is not None:
             copy.update = [u.clone() for u in self.update]
@@ -170,10 +226,16 @@ class Var:
         copy.grad = self.grad
         copy.loss = self.loss
         copy.loss_approx = self.loss_approx
+        copy.closure = self.closure
         copy.post_step_hooks = self.post_step_hooks
         copy.stop = self.stop
         copy.skip_update = self.skip_update

+        copy.modular = self.modular
+        copy.attrs = self.attrs
+        copy.storage = self.storage
+        copy.should_terminate = self.should_terminate
+
         return copy

     def update_attrs_from_clone_(self, var: "Var"):
@@ -182,11 +244,16 @@ class Var:
         object. This propagates any newly computed loss or gradient values
         from the child's context back to the parent `Vars` if the parent
         didn't have them computed already.
+
+        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
+        if the child updates them, the update will affect the parent too.
         """
         if self.loss is None: self.loss = var.loss
         if self.loss_approx is None: self.loss_approx = var.loss_approx
         if self.grad is None: self.grad = var.grad

+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
     def zero_grad(self, set_to_none=True):
         if set_to_none:
             for p in self.params: p.grad = None
@@ -196,6 +263,7 @@ class Var:

 # endregion

+
 # region Module
 # ---------------------------------- module ---------------------------------- #
 class Module(ABC):
@@ -308,17 +376,16 @@ class Module(ABC):

         If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

-        .. code:: py
-
-            exp_avg = self.state_vals("exp_avg")
-            # returns cls (by default TensorList)
-
-            exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
-            # returns list of cls
+        ```python
+        exp_avg = self.state_vals("exp_avg")
+        # returns cls (by default TensorList)

-            exp_avg = self.state_vals(["exp_avg"])
-            # always returns a list of cls, even if got a single key
+        exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
+        # returns list of cls

+        exp_avg = self.state_vals(["exp_avg"])
+        # always returns a list of cls, even if got a single key
+        ```

         Args:
             *keys (str):
@@ -358,6 +425,26 @@ class Module(ABC):
     # # if isinstance(params, Vars): params = params.params
     # return itemgetter(*keys)(self.settings[params[0]])

+    def clear_state_keys(self, *keys:str):
+        for s in self.state.values():
+            for k in keys:
+                if k in s: del s[k]
+
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
+    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
+        if isinstance(keys, str):
+            for p,v in zip(params, values):
+                state = self.state[p]
+                state[keys] = v
+            return
+
+        for p, *p_v in zip(params, *values):
+            state = self.state[p]
+            for k,v in zip(keys, p_v): state[k] = v
+
     def state_dict(self):
         """state dict"""
         packed_state = {id(k):v for k,v in self.state.items()}
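For illustration (not part of the diff): the new ``store`` helper above is the write-side counterpart of ``state_vals``. A hypothetical call site inside a module method, with ``params``, ``exp_avgs`` and ``exp_avg_sqs`` assumed to be parallel sequences of tensors, might look like:

```python
# hypothetical usage of the helpers added above
self.store(params, "exp_avg", exp_avgs)                                 # one key, one tensor per parameter
self.store(params, ["exp_avg", "exp_avg_sq"], [exp_avgs, exp_avg_sqs])  # several keys at once
exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")          # read them back
self.clear_state_keys("exp_avg", "exp_avg_sq")                          # drop them from every per-parameter state
```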
@@ -377,7 +464,8 @@ class Module(ABC):
         }
         return state_dict

-    def load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+    def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+        """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
         # load state
         state = state_dict['state']
         self.state.clear()
@@ -396,29 +484,159 @@ class Module(ABC):

         # children
         for k, v in state_dict['children']:
-            if k in self.children: self.children[k].load_state_dict(v, id_to_tensor)
+            if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
             else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')

         # extra info
         self._extra_unpack(state_dict['extra'])

     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    @abstractmethod
     def step(self, var: Var) -> Var:
-        """performs a step, returns new var but may update them in-place."""
+        """performs a step, returns new ``var`` but may update it in-place."""
+        self.update(var)
+        return self.apply(var)
+
+    def update(self, var:Var) -> Any:
+        """Updates the internal state of this module. This should not modify ``var.update``.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        """
+
+    def apply(self, var: Var) -> Var:
+        """Applies this module to ``var.get_update()``.
+        This should not modify the internal state of this module if possible.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        """
+        return self.step(var)
+
+    def get_H(self, var: Var) -> LinearOperator | None:
+        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
+        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
+        # if this method is not defined it searches in children
+        # this should be overwritten to return None if child params are different from this modules params
+        H = None
+        for k,v in self.children.items():
+            H_v = v.get_H(var)
+
+            if (H is not None) and (H_v is not None):
+                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
+
+            if H_v is not None: H = H_v
+
+        return H

     def reset(self):
-        """Resets the internal state of the module (e.g. momentum)."""
-        # no complex logic is allowed there because this is overridden by many modules
-        # where super().reset() shouldn't be called
+        """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
         self.state.clear()
         self.global_state.clear()
+        for c in self.children.values(): c.reset()
+
+    def reset_for_online(self):
+        """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
+        which may become inaccurate due to mini-batching.
+
+        ``Online`` module calls ``reset_for_online``,
+        then it calls ``update`` with previous parameters,
+        then it calls ``update`` with current parameters,
+        and then ``apply``.
+        """
+        for c in self.children.values(): c.reset_for_online()

     def _extra_pack(self):
+        """extra information to store in state_dict of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the state_dict."""
         return {}

     def _extra_unpack(self, x):
-        pass
+        """``_extra_pack`` return will be passed to this method when loading state_dict.
+        This method is called after loading the rest of the state dict"""
+
+
+
+    # ------------------------------ HELPER METHODS ------------------------------ #
+    @torch.no_grad
+    def Hvp(
+        self,
+        v: Sequence[torch.Tensor],
+        at_x0: bool,
+        var: Var,
+        rgrad: Sequence[torch.Tensor] | None,
+        hvp_method: Literal['autograd', 'forward', 'central'],
+        h: float,
+        normalize: bool,
+        retain_grad: bool,
+    ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
+        """
+        Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
+        possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
+        Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+        Single sample example:
+
+        ```python
+        Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+        ```
+
+        Multiple samples example:
+
+        ```python
+        D = None
+        rgrad = None
+        for i in range(n_samples):
+            v = [torch.randn_like(p) for p in params]
+            Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+            if D is None: D = Hvp
+            else: torch._foreach_add_(D, Hvp)
+
+        if n_samples > 1: torch._foreach_div_(D, n_samples)
+        ```
+
+        Args:
+            v (Sequence[torch.Tensor]): vector in hessian-vector product
+            at_x0 (bool): whether this is being called at original or perturbed parameters.
+            var (Var): Var
+            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+            hvp_method (str): hvp method.
+            h (float): finite difference step size
+            normalize (bool): whether to normalize v for finite difference
+            retain_grad (bool): retain grad
+        """
+        # get grad
+        if rgrad is None and hvp_method in ('autograd', 'forward'):
+            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
+            else:
+                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+                with torch.enable_grad():
+                    loss = var.closure()
+                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
+
+        if hvp_method == 'autograd':
+            assert rgrad is not None
+            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
+
+        elif hvp_method == 'forward':
+            assert rgrad is not None
+            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
+
+        else:
+            raise ValueError(hvp_method)
+
+        return Hvp, rgrad
+
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        if seed is None: return None
+
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state['generator']

 # endregion

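For illustration (not part of the diff): a minimal sketch of a module written against the new ``update``/``apply`` split, so that meta-modules such as ``tz.m.Online`` can refresh state and apply the transform separately. The EMA rule, the ``SketchEMA`` name and the argument-free base constructor are assumptions of this sketch.

```python
class SketchEMA(Module):
    """hypothetical module: ``update`` only touches state, ``apply`` only transforms ``var.update``"""
    def update(self, var):
        # accumulate an exponential moving average of the incoming update
        for p, u in zip(var.params, var.get_update()):
            ema = self.state[p].get("ema")
            if ema is None: self.state[p]["ema"] = u.clone()
            else: ema.lerp_(u, 0.1)

    def apply(self, var):
        # replace the update with the stored average; internal state is left untouched
        var.update = [self.state[p]["ema"].clone() for p in var.params]
        return var
```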
@@ -440,6 +658,27 @@ def unroll_modules(*modules: Chainable) -> list[Module]:

 # region Modular
 # ---------------------------------- Modular --------------------------------- #
+
+class _EvalCounterClosure:
+    """keeps track of how many times closure has been evaluated, and sets closure return"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+        v = self.closure(*args, **kwargs)
+
+        # set closure return on 1st evaluation
+        if self.modular._closure_return is None:
+            self.modular._closure_return = v
+
+        self.modular.num_evaluations += 1
+        return v
+
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
 class Modular(torch.optim.Optimizer):
@@ -456,6 +695,7 @@ class Modular(torch.optim.Optimizer):
     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
         self.model: torch.nn.Module | None = None
         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
         if isinstance(params, torch.nn.Module):
@@ -489,14 +729,33 @@ class Modular(torch.optim.Optimizer):
         for m in self.unrolled_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)

-        # note - this is what super init does:
+        # note - this is what super().__init__(param_groups, defaults=defaults) does:

         # self.defaults = defaults
         # for param_group in param_groups:
         #     self.add_param_group(param_group)

+        # add_param_group adds a ChainMap where defaults are lowest priority,
+        # and entries specifed in param_groups or scheduler are higher priority.
+        # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
+        # in each module, settings passed to that module by calling set_param_groups are highest priority
+
         self.current_step = 0
-        """The global step counter for the optimizer."""
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
+
+        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+        # we want to return original loss so this attribute is used
+        self._closure_return = None
+        """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
+
+        self.attrs = {}
+        """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+        self.should_terminate = False
+        """is set to True by termination criteria modules."""

     def add_param_group(self, param_group: dict[str, Any]):
         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -542,10 +801,13 @@ class Modular(torch.optim.Optimizer):

         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
         for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
-            m.load_state_dict(sd, id_to_tensor)
+            m._load_state_dict(sd, id_to_tensor)
+

+    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+        # clear closure return from previous step
+        self._closure_return = None

-    def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
         # propagate global per-parameter setting overrides
         for g in self.param_groups:
             settings = dict(g.maps[0]) # ignore defaults
@@ -558,15 +820,17 @@ class Modular(torch.optim.Optimizer):

         # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)

         # if closure is None, assume backward has been called and gather grads
         if closure is None:
             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            self.num_evaluations += 1

+        n_modules = len(self.modules)
+        if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
         last_module = self.modules[-1]
         last_lr = last_module.defaults.get('lr', None)
-        n_modules = len(self.modules)

         # step
         for i, module in enumerate(self.modules):
@@ -586,11 +850,17 @@ class Modular(torch.optim.Optimizer):
             with torch.no_grad():
                 torch._foreach_sub_(params, var.get_update())

+        # update attributes
+        self.attrs.update(var.attrs)
+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+        # hooks
         for hook in var.post_step_hooks:
             hook(self, var)

         self.current_step += 1
-        return var.loss if var.loss is not None else var.loss_approx
+        #return var.loss if var.loss is not None else var.loss_approx
+        return self._closure_return

     def __repr__(self):
         return f'Modular({", ".join(str(m) for m in self.modules)})'
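For illustration (not part of the diff): with the new signature, ``step`` takes a closure that accepts a ``backward`` flag, optionally a precomputed ``loss``, and extra keyword arguments that land in ``var.storage``; it returns the value of the first closure evaluation of the step. ``model``, ``criterion``, ``inputs``, ``targets`` and ``opt`` are placeholder names.

```python
def closure(backward=True):
    loss = criterion(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)    # returns the first closure evaluation of this step
if opt.should_terminate:    # set by termination-criteria modules
    print("stopping after", opt.num_evaluations, "objective evaluations")
```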
@@ -606,6 +876,21 @@ class Chain(Module):
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)

+    def update(self, var):
+        # note here that `update` and `apply` shouldn't be used directly
+        # as it will update all modules, and then apply all modules
+        # it is used in specific cases like Chain as trust region hessian module
+        for i in range(len(self.children)):
+            self.children[f'module_{i}'].update(var)
+            if var.stop: break
+        return var
+
+    def apply(self, var):
+        for i in range(len(self.children)):
+            var = self.children[f'module_{i}'].apply(var)
+            if var.stop: break
+        return var
+
     def step(self, var):
         for i in range(len(self.children)):
             var = self.children[f'module_{i}'].step(var)
@@ -616,7 +901,7 @@ class Chain(Module):
         s = self.__class__.__name__
         if self.children:
             if s == 'Chain': s = 'C' # to shorten it
-            s = f'{s}({", ".join(str(m) for m in self.children.values())}'
+            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
         return s

 def maybe_chain(*modules: Chainable) -> Module: